use super::AttentionError;
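/// Numerically stable softmax computed in place: the running maximum is
/// subtracted before exponentiation, then the values are normalized to sum to 1.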
#[inline]
fn softmax_inplace(logits: &mut [f32]) {
if logits.is_empty() {
return;
}
let max_val = logits.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
let mut sum = 0.0_f32;
for v in logits.iter_mut() {
*v = (*v - max_val).exp();
sum += *v;
}
if sum > 0.0 {
let inv = 1.0 / sum;
for v in logits.iter_mut() {
*v *= inv;
}
}
}
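/// Dot product of two equal-length `f32` slices.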
#[inline]
fn dot_f32(a: &[f32], b: &[f32]) -> f32 {
a.iter().zip(b.iter()).map(|(&x, &y)| x * y).sum()
}
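/// Configuration for multi-query attention (MQA): `num_heads` query heads share a
/// single key/value head of dimension `head_dim`; `scale` is `1 / sqrt(head_dim)`.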
#[derive(Debug, Clone)]
pub struct MqaConfig {
pub num_heads: usize,
pub head_dim: usize,
pub scale: f32,
}
impl MqaConfig {
pub fn new(num_heads: usize, head_dim: usize) -> Self {
let scale = 1.0 / (head_dim as f32).sqrt();
Self { num_heads, head_dim, scale }
}
}
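/// Multi-query attention over flattened, row-major tensors.
///
/// `query` is laid out as `[batch, seq_len, num_heads, head_dim]`, while `key` and
/// `value` share a single head and are `[batch, seq_len, head_dim]`. The optional
/// `mask` of shape `[batch, seq_len, seq_len]` is added to the scaled logits before
/// the softmax. The output has the same shape as `query`.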
pub fn multi_query_attention(
query: &[f32],
key: &[f32],
value: &[f32],
mask: Option<&[f32]>,
config: &MqaConfig,
batch: usize,
seq_len: usize,
) -> Result<Vec<f32>, AttentionError> {
let h = config.num_heads;
let d = config.head_dim;
if h == 0 || d == 0 || batch == 0 || seq_len == 0 {
return Err(AttentionError::EmptyInput);
}
let expected_q = batch * seq_len * h * d;
let expected_kv = batch * seq_len * d;
if query.len() != expected_q {
return Err(AttentionError::QKShapeMismatch {
q: format!("query len={}, expected={}", query.len(), expected_q),
k: String::new(),
});
}
if key.len() != expected_kv || value.len() != expected_kv {
return Err(AttentionError::QKShapeMismatch {
q: format!("key/value len expected={}", expected_kv),
k: format!("key={}, value={}", key.len(), value.len()),
});
}
if let Some(m) = mask {
let expected_mask = batch * seq_len * seq_len;
if m.len() != expected_mask {
return Err(AttentionError::QKShapeMismatch {
q: format!("mask len={}", m.len()),
k: format!("expected {}", expected_mask),
});
}
}
let scale = config.scale;
let mut output = vec![0.0_f32; expected_q];
for b in 0..batch {
for head in 0..h {
for q_pos in 0..seq_len {
let q_offset = b * seq_len * h * d + q_pos * h * d + head * d;
let q_vec = &query[q_offset..q_offset + d];
let mut scores: Vec<f32> = (0..seq_len)
.map(|k_pos| {
let k_offset = b * seq_len * d + k_pos * d;
dot_f32(q_vec, &key[k_offset..k_offset + d]) * scale
})
.collect();
if let Some(m) = mask {
let mask_base = b * seq_len * seq_len + q_pos * seq_len;
for (k_pos, s) in scores.iter_mut().enumerate() {
*s += m[mask_base + k_pos];
}
}
softmax_inplace(&mut scores);
let out_offset = q_offset;
for (k_pos, &w) in scores.iter().enumerate() {
let v_offset = b * seq_len * d + k_pos * d;
for dim in 0..d {
output[out_offset + dim] += w * value[v_offset + dim];
}
}
}
}
}
Ok(output)
}
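/// Configuration for grouped-query attention (GQA). `num_query_heads` must be a
/// multiple of `num_kv_heads`, so every key/value head serves a fixed-size group
/// of query heads.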
#[derive(Debug, Clone)]
pub struct GqaConfig {
pub num_query_heads: usize,
pub num_kv_heads: usize,
pub head_dim: usize,
pub scale: f32,
}
impl GqaConfig {
pub fn new(
num_query_heads: usize,
num_kv_heads: usize,
head_dim: usize,
) -> Result<Self, AttentionError> {
if num_kv_heads == 0 || num_query_heads % num_kv_heads != 0 {
return Err(AttentionError::InvalidHeads {
dm: num_query_heads,
nh: num_kv_heads,
});
}
let scale = 1.0 / (head_dim as f32).sqrt();
Ok(Self { num_query_heads, num_kv_heads, head_dim, scale })
}
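/// Number of query heads that share each key/value head.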
#[inline]
pub fn queries_per_kv(&self) -> usize {
self.num_query_heads / self.num_kv_heads
}
}
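/// Expands a `[batch, seq_len, num_kv_heads, head_dim]` tensor to
/// `[batch, seq_len, num_kv_heads * n_rep, head_dim]` by repeating each KV head
/// `n_rep` times, so the GQA loop can index keys and values per query head.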
pub fn repeat_kv(
kv: &[f32],
n_rep: usize,
batch: usize,
seq_len: usize,
num_kv_heads: usize,
head_dim: usize,
) -> Vec<f32> {
if n_rep == 1 {
return kv.to_vec();
}
let num_q_heads = num_kv_heads * n_rep;
let mut out = vec![0.0_f32; batch * seq_len * num_q_heads * head_dim];
for b in 0..batch {
for s in 0..seq_len {
for kv_h in 0..num_kv_heads {
let src_offset = b * seq_len * num_kv_heads * head_dim
+ s * num_kv_heads * head_dim
+ kv_h * head_dim;
let src = &kv[src_offset..src_offset + head_dim];
for rep in 0..n_rep {
let q_h = kv_h * n_rep + rep;
let dst_offset = b * seq_len * num_q_heads * head_dim
+ s * num_q_heads * head_dim
+ q_h * head_dim;
out[dst_offset..dst_offset + head_dim].copy_from_slice(src);
}
}
}
}
out
}
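/// Grouped-query attention over flattened, row-major tensors.
///
/// `query` is `[batch, seq_len, num_query_heads, head_dim]`; `key` and `value` are
/// `[batch, seq_len, num_kv_heads, head_dim]` and are expanded with [`repeat_kv`]
/// before the scaled-dot-product loop. The optional additive `mask` has shape
/// `[batch, seq_len, seq_len]`.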
pub fn grouped_query_attention(
query: &[f32],
key: &[f32],
value: &[f32],
mask: Option<&[f32]>,
config: &GqaConfig,
batch: usize,
seq_len: usize,
) -> Result<Vec<f32>, AttentionError> {
let nq = config.num_query_heads;
let nkv = config.num_kv_heads;
let d = config.head_dim;
let n_rep = config.queries_per_kv();
if nq == 0 || nkv == 0 || d == 0 || batch == 0 || seq_len == 0 {
return Err(AttentionError::EmptyInput);
}
let expected_q = batch * seq_len * nq * d;
let expected_kv = batch * seq_len * nkv * d;
if query.len() != expected_q {
return Err(AttentionError::QKShapeMismatch {
q: format!("query len={}, expected={}", query.len(), expected_q),
k: String::new(),
});
}
if key.len() != expected_kv {
return Err(AttentionError::QKShapeMismatch {
q: format!("key len={}, expected={}", key.len(), expected_kv),
k: String::new(),
});
}
if value.len() != expected_kv {
return Err(AttentionError::QKShapeMismatch {
q: format!("value len={}, expected={}", value.len(), expected_kv),
k: String::new(),
});
}
if let Some(m) = mask {
let expected_mask = batch * seq_len * seq_len;
if m.len() != expected_mask {
return Err(AttentionError::QKShapeMismatch {
q: format!("mask len={}, expected={}", m.len(), expected_mask),
k: String::new(),
});
}
}
let key_expanded = repeat_kv(key, n_rep, batch, seq_len, nkv, d);
let val_expanded = repeat_kv(value, n_rep, batch, seq_len, nkv, d);
let scale = config.scale;
let mut output = vec![0.0_f32; expected_q];
for b in 0..batch {
for head in 0..nq {
for q_pos in 0..seq_len {
let q_offset = b * seq_len * nq * d + q_pos * nq * d + head * d;
let q_vec = &query[q_offset..q_offset + d];
let mut scores: Vec<f32> = (0..seq_len)
.map(|k_pos| {
let k_offset = b * seq_len * nq * d + k_pos * nq * d + head * d;
dot_f32(q_vec, &key_expanded[k_offset..k_offset + d]) * scale
})
.collect();
if let Some(m) = mask {
let mask_base = b * seq_len * seq_len + q_pos * seq_len;
for (k_pos, s) in scores.iter_mut().enumerate() {
*s += m[mask_base + k_pos];
}
}
softmax_inplace(&mut scores);
let out_offset = q_offset;
for (k_pos, &w) in scores.iter().enumerate() {
let v_offset = b * seq_len * nq * d + k_pos * nq * d + head * d;
for dim in 0..d {
output[out_offset + dim] += w * val_expanded[v_offset + dim];
}
}
}
}
}
Ok(output)
}
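/// Configuration accepted by [`alibi_attention`], which only validates shapes;
/// the actual computation needs the head dimension carried by [`AliBiFullConfig`].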
#[derive(Debug, Clone)]
pub struct AliBiConfig {
pub num_heads: usize,
pub max_seq_len: usize,
}
impl AliBiConfig {
pub fn new(num_heads: usize, max_seq_len: usize) -> Self {
Self { num_heads, max_seq_len }
}
}
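/// Per-head ALiBi slopes. For a power-of-two head count the slopes are
/// `2^(-8 * h / num_heads)` for `h = 1..=num_heads`; otherwise the remainder is
/// filled from the next power of two, following the reference ALiBi construction.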
pub fn alibi_slopes(num_heads: usize) -> Vec<f32> {
if num_heads == 0 {
return Vec::new();
}
let next_pow2 = num_heads.next_power_of_two();
let slopes_full: Vec<f32> = (1..=next_pow2)
.map(|h| 2.0_f32.powf(-(8.0 * h as f32 / next_pow2 as f32)))
.collect();
if next_pow2 == num_heads {
return slopes_full;
}
let half_pow2 = next_pow2 / 2;
let mut result: Vec<f32> = (1..=half_pow2)
.map(|h| 2.0_f32.powf(-(8.0 * h as f32 / half_pow2 as f32)))
.collect();
// For non-power-of-two head counts, follow the reference ALiBi construction:
// take every other slope from the next power of two so no slope is duplicated.
for s in slopes_full.into_iter().step_by(2) {
if result.len() == num_heads {
break;
}
result.push(s);
}
result
}
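/// Dense ALiBi bias of shape `[num_heads, seq_len, seq_len]`, where entry
/// `(h, i, j)` is `-slope[h] * |i - j|`: zero on the diagonal and symmetric in
/// distance.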
pub fn alibi_bias(num_heads: usize, seq_len: usize) -> Vec<f32> {
let slopes = alibi_slopes(num_heads);
let mut out = vec![0.0_f32; num_heads * seq_len * seq_len];
for (h, &slope) in slopes.iter().enumerate() {
for i in 0..seq_len {
for j in 0..seq_len {
let dist = (i as isize - j as isize).unsigned_abs() as f32;
out[h * seq_len * seq_len + i * seq_len + j] = -slope * dist;
}
}
}
out
}
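/// Shape-validation entry point for ALiBi attention. After checking the input
/// lengths against `config.num_heads`, it returns an error directing callers to
/// [`alibi_attention_full`], which knows the head dimension.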
pub fn alibi_attention(
query: &[f32],
key: &[f32],
value: &[f32],
config: &AliBiConfig,
) -> Result<Vec<f32>, AttentionError> {
let h = config.num_heads;
if h == 0 {
return Err(AttentionError::EmptyInput);
}
let total = query.len();
if total == 0 || key.len() != total || value.len() != total {
return Err(AttentionError::QKShapeMismatch {
q: format!("query len={}", total),
k: format!("key={}, value={}", key.len(), value.len()),
});
}
if total % h != 0 {
return Err(AttentionError::InvalidHeads { dm: total, nh: h });
}
// `total / h` alone cannot be split into (seq_len, head_dim), so this entry
// point stops after the shape checks and points callers at the full variant.
Err(AttentionError::QKShapeMismatch {
q: "alibi_attention requires AliBiFullConfig (use alibi_attention_full)".to_string(),
k: String::new(),
})
}
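/// Full ALiBi configuration: head count, head dimension, and the
/// `1 / sqrt(head_dim)` scale.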
#[derive(Debug, Clone)]
pub struct AliBiFullConfig {
pub num_heads: usize,
pub head_dim: usize,
pub scale: f32,
}
impl AliBiFullConfig {
pub fn new(num_heads: usize, head_dim: usize) -> Self {
let scale = 1.0 / (head_dim as f32).sqrt();
Self { num_heads, head_dim, scale }
}
}
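/// ALiBi attention for a single (unbatched) sequence laid out as
/// `[seq_len, num_heads, head_dim]`. The bias `-slope[h] * |q_pos - k_pos|` is
/// added to each scaled logit before the softmax, so no mask tensor is built.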
pub fn alibi_attention_full(
query: &[f32],
key: &[f32],
value: &[f32],
config: &AliBiFullConfig,
seq_len: usize,
) -> Result<Vec<f32>, AttentionError> {
let h = config.num_heads;
let d = config.head_dim;
let scale = config.scale;
if h == 0 || d == 0 || seq_len == 0 {
return Err(AttentionError::EmptyInput);
}
let expected = seq_len * h * d;
if query.len() != expected || key.len() != expected || value.len() != expected {
return Err(AttentionError::QKShapeMismatch {
q: format!("expected {} elements, got q={}", expected, query.len()),
k: format!("k={}, v={}", key.len(), value.len()),
});
}
let slopes = alibi_slopes(h);
let mut output = vec![0.0_f32; expected];
for head in 0..h {
let slope = slopes[head];
for q_pos in 0..seq_len {
let q_offset = q_pos * h * d + head * d;
let q_vec = &query[q_offset..q_offset + d];
let mut scores: Vec<f32> = (0..seq_len)
.map(|k_pos| {
let k_offset = k_pos * h * d + head * d;
let dot = dot_f32(q_vec, &key[k_offset..k_offset + d]);
let bias = -slope * (q_pos as isize - k_pos as isize).unsigned_abs() as f32;
dot * scale + bias
})
.collect();
softmax_inplace(&mut scores);
let out_offset = q_offset;
for (k_pos, &w) in scores.iter().enumerate() {
let v_offset = k_pos * h * d + head * d;
for dim in 0..d {
output[out_offset + dim] += w * value[v_offset + dim];
}
}
}
}
Ok(output)
}
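/// Configuration for encoder-decoder cross-attention. Both hidden sizes must
/// equal `num_heads * head_dim`.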
#[derive(Debug, Clone)]
pub struct CrossAttentionConfig {
pub num_heads: usize,
pub head_dim: usize,
pub encoder_hidden_size: usize,
pub decoder_hidden_size: usize,
pub scale: f32,
}
impl CrossAttentionConfig {
pub fn new(
num_heads: usize,
head_dim: usize,
encoder_hidden_size: usize,
decoder_hidden_size: usize,
) -> Result<Self, AttentionError> {
let required = num_heads * head_dim;
if encoder_hidden_size != required {
return Err(AttentionError::InvalidHeads {
dm: encoder_hidden_size,
nh: num_heads,
});
}
if decoder_hidden_size != required {
return Err(AttentionError::InvalidHeads {
dm: decoder_hidden_size,
nh: num_heads,
});
}
let scale = 1.0 / (head_dim as f32).sqrt();
Ok(Self {
num_heads,
head_dim,
encoder_hidden_size,
decoder_hidden_size,
scale,
})
}
}
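/// Cross-attention in which the decoder `query` (`[decoder_seq, num_heads * head_dim]`)
/// attends over the encoder `key_value` (`[encoder_seq, num_heads * head_dim]`),
/// which serves as both keys and values. The optional additive `mask` has shape
/// `[decoder_seq, encoder_seq]`.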
pub fn cross_attention(
query: &[f32],
key_value: &[f32],
mask: Option<&[f32]>,
config: &CrossAttentionConfig,
decoder_seq: usize,
encoder_seq: usize,
) -> Result<Vec<f32>, AttentionError> {
let h = config.num_heads;
let d = config.head_dim;
let scale = config.scale;
let hidden = h * d;
if h == 0 || d == 0 || decoder_seq == 0 || encoder_seq == 0 {
return Err(AttentionError::EmptyInput);
}
let expected_q = decoder_seq * hidden;
let expected_kv = encoder_seq * hidden;
if query.len() != expected_q {
return Err(AttentionError::QKShapeMismatch {
q: format!("query len={}, expected={}", query.len(), expected_q),
k: String::new(),
});
}
if key_value.len() != expected_kv {
return Err(AttentionError::QKShapeMismatch {
q: format!("key_value len={}, expected={}", key_value.len(), expected_kv),
k: String::new(),
});
}
if let Some(m) = mask {
let expected_mask = decoder_seq * encoder_seq;
if m.len() != expected_mask {
return Err(AttentionError::QKShapeMismatch {
q: format!("mask len={}, expected={}", m.len(), expected_mask),
k: String::new(),
});
}
}
let mut output = vec![0.0_f32; expected_q];
for head in 0..h {
for q_pos in 0..decoder_seq {
let q_offset = q_pos * hidden + head * d;
let q_vec = &query[q_offset..q_offset + d];
let mut scores: Vec<f32> = (0..encoder_seq)
.map(|k_pos| {
let k_offset = k_pos * hidden + head * d;
dot_f32(q_vec, &key_value[k_offset..k_offset + d]) * scale
})
.collect();
if let Some(m) = mask {
let mask_base = q_pos * encoder_seq;
for (k_pos, s) in scores.iter_mut().enumerate() {
*s += m[mask_base + k_pos];
}
}
softmax_inplace(&mut scores);
let out_offset = q_offset;
for (k_pos, &w) in scores.iter().enumerate() {
let v_offset = k_pos * hidden + head * d;
for dim in 0..d {
output[out_offset + dim] += w * key_value[v_offset + dim];
}
}
}
}
Ok(output)
}
#[cfg(test)]
mod tests {
use super::*;
fn make_tensor(len: usize, seed: f32) -> Vec<f32> {
(0..len)
.map(|i| ((i as f32 * seed * 0.07 + seed).sin() * 0.3 + 0.1).abs())
.collect()
}
fn assert_close_f32(a: &[f32], b: &[f32], tol: f32, label: &str) {
assert_eq!(a.len(), b.len(), "{label}: length mismatch {} vs {}", a.len(), b.len());
for (i, (&x, &y)) in a.iter().zip(b.iter()).enumerate() {
assert!(
(x - y).abs() < tol,
"{label}[{i}]: {x} vs {y} (diff={})",
(x - y).abs()
);
}
}
#[test]
fn test_softmax_inplace_sums_to_one() {
let mut v = vec![1.0_f32, 2.0, 3.0, 4.0];
softmax_inplace(&mut v);
let sum: f32 = v.iter().sum();
assert!((sum - 1.0).abs() < 1e-6, "softmax should sum to 1, got {sum}");
}
#[test]
fn test_softmax_inplace_uniform_input() {
let mut v = vec![1.0_f32; 4];
softmax_inplace(&mut v);
for &x in &v {
assert!((x - 0.25).abs() < 1e-6, "uniform softmax should be 0.25, got {x}");
}
}
#[test]
fn test_softmax_inplace_single_element() {
let mut v = vec![42.0_f32];
softmax_inplace(&mut v);
assert!((v[0] - 1.0).abs() < 1e-6, "single element softmax = 1.0");
}
#[test]
fn test_alibi_slopes_power_of_two_length() {
let slopes = alibi_slopes(8);
assert_eq!(slopes.len(), 8);
for &s in &slopes {
assert!(s > 0.0 && s <= 1.0, "slope out of range: {s}");
}
for i in 1..slopes.len() {
assert!(
slopes[i] < slopes[i - 1],
"slopes should be decreasing: s[{}]={} >= s[{}]={}",
i, slopes[i], i - 1, slopes[i - 1]
);
}
}
#[test]
fn test_alibi_slopes_single_head() {
let slopes = alibi_slopes(1);
assert_eq!(slopes.len(), 1);
let expected = 2.0_f32.powf(-8.0);
assert!(
(slopes[0] - expected).abs() < 1e-6,
"single head slope: expected {expected}, got {}", slopes[0]
);
}
#[test]
fn test_alibi_slopes_returns_correct_count() {
for h in [1, 2, 3, 4, 5, 6, 7, 8, 12, 16] {
let slopes = alibi_slopes(h);
assert_eq!(slopes.len(), h, "slopes length should equal num_heads={h}");
}
}
#[test]
fn test_alibi_slopes_zero_heads() {
let slopes = alibi_slopes(0);
assert!(slopes.is_empty(), "zero heads → empty slopes");
}
#[test]
fn test_alibi_bias_shape() {
let h = 4;
let seq = 6;
let bias = alibi_bias(h, seq);
assert_eq!(bias.len(), h * seq * seq, "alibi_bias shape mismatch");
}
#[test]
fn test_alibi_bias_diagonal_is_zero() {
let h = 4;
let seq = 6;
let bias = alibi_bias(h, seq);
for head in 0..h {
for i in 0..seq {
let val = bias[head * seq * seq + i * seq + i];
assert!(
val.abs() < 1e-6,
"diagonal should be 0, got {val} at head={head} pos={i}"
);
}
}
}
#[test]
fn test_alibi_bias_negative_off_diagonal() {
let h = 4;
let seq = 6;
let bias = alibi_bias(h, seq);
for head in 0..h {
for i in 0..seq {
for j in 0..seq {
if i != j {
let val = bias[head * seq * seq + i * seq + j];
assert!(
val < 0.0,
"off-diagonal bias should be < 0, got {val} head={head} i={i} j={j}"
);
}
}
}
}
}
#[test]
fn test_alibi_bias_symmetric_in_distance() {
let h = 2;
let seq = 8;
let bias = alibi_bias(h, seq);
for head in 0..h {
for i in 0..seq {
for j in 0..seq {
let bij = bias[head * seq * seq + i * seq + j];
let bji = bias[head * seq * seq + j * seq + i];
assert!(
(bij - bji).abs() < 1e-6,
"bias not symmetric: [{head},{i},{j}]={bij} vs [{head},{j},{i}]={bji}"
);
}
}
}
}
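#[test]
fn test_alibi_bias_first_row_matches_slope_times_distance() {
// Illustrative sketch: for head 0, row 0 of the bias should equal
// -slope[0] * j, matching the formula used in `alibi_bias`.
let h = 3;
let seq = 5;
let slopes = alibi_slopes(h);
let bias = alibi_bias(h, seq);
for j in 0..seq {
let expected = -slopes[0] * j as f32;
let got = bias[j];
assert!(
(got - expected).abs() < 1e-6,
"bias[0, 0, {j}] = {got}, expected {expected}"
);
}
}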
#[test]
fn test_alibi_attention_output_shape() {
let h = 2;
let d = 4;
let seq = 6;
let config = AliBiFullConfig::new(h, d);
let total = seq * h * d;
let q = make_tensor(total, 1.0);
let k = make_tensor(total, 2.0);
let v = make_tensor(total, 3.0);
let out = alibi_attention_full(&q, &k, &v, &config, seq)
.expect("alibi_attention_full shape");
assert_eq!(out.len(), total, "output shape mismatch");
}
#[test]
fn test_alibi_attention_output_finite() {
let h = 4;
let d = 8;
let seq = 10;
let config = AliBiFullConfig::new(h, d);
let total = seq * h * d;
let q = make_tensor(total, 1.1);
let k = make_tensor(total, 2.2);
let v = make_tensor(total, 3.3);
let out = alibi_attention_full(&q, &k, &v, &config, seq)
.expect("alibi output finite");
for (i, &val) in out.iter().enumerate() {
assert!(val.is_finite(), "output[{i}] = {val} is not finite");
}
}
#[test]
fn test_alibi_attention_differs_from_standard_sdpa() {
let h = 2;
let d = 4;
let seq = 8;
let config = AliBiFullConfig::new(h, d);
let total = seq * h * d;
let q = make_tensor(total, 1.5);
let k = make_tensor(total, 2.5);
let v = make_tensor(total, 3.5);
let alibi_out = alibi_attention_full(&q, &k, &v, &config, seq)
.expect("alibi");
let scale = config.scale;
let mut std_out = vec![0.0_f32; total];
for head in 0..h {
for q_pos in 0..seq {
let q_off = q_pos * h * d + head * d;
let q_vec = &q[q_off..q_off + d];
let mut scores: Vec<f32> = (0..seq)
.map(|k_pos| {
let k_off = k_pos * h * d + head * d;
dot_f32(q_vec, &k[k_off..k_off + d]) * scale
})
.collect();
softmax_inplace(&mut scores);
for (k_pos, &w) in scores.iter().enumerate() {
let v_off = k_pos * h * d + head * d;
for dim in 0..d {
std_out[q_off + dim] += w * v[v_off + dim];
}
}
}
}
let diff: f32 = alibi_out.iter().zip(std_out.iter()).map(|(a, b)| (a - b).abs()).sum();
assert!(diff > 1e-4, "ALiBi should differ from standard SDPA, diff={diff}");
}
#[test]
fn test_alibi_attention_full_error_on_empty() {
let config = AliBiFullConfig::new(2, 4);
let result = alibi_attention_full(&[], &[], &[], &config, 0);
assert!(result.is_err(), "should error on empty input");
}
#[test]
fn test_mqa_output_shape() {
let h = 4;
let d = 8;
let batch = 2;
let seq = 5;
let config = MqaConfig::new(h, d);
let q = make_tensor(batch * seq * h * d, 1.0);
let k = make_tensor(batch * seq * d, 2.0);
let v = make_tensor(batch * seq * d, 3.0);
let out = multi_query_attention(&q, &k, &v, None, &config, batch, seq)
.expect("mqa shape");
assert_eq!(out.len(), batch * seq * h * d);
}
#[test]
fn test_mqa_single_head_equals_standard() {
let h = 1;
let d = 4;
let batch = 1;
let seq = 6;
let config = MqaConfig::new(h, d);
let total_q = batch * seq * h * d;
let total_kv = batch * seq * d;
let q = make_tensor(total_q, 1.2);
let k = make_tensor(total_kv, 2.3);
let v = make_tensor(total_kv, 3.4);
let out = multi_query_attention(&q, &k, &v, None, &config, batch, seq)
.expect("mqa single head");
// With a single head, MQA reduces to plain scaled dot-product attention, so
// compare against a direct reference computation.
let scale = config.scale;
let mut expected = vec![0.0_f32; total_q];
for q_pos in 0..seq {
let q_vec = &q[q_pos * d..q_pos * d + d];
let mut scores: Vec<f32> = (0..seq)
.map(|k_pos| dot_f32(q_vec, &k[k_pos * d..k_pos * d + d]) * scale)
.collect();
softmax_inplace(&mut scores);
for (k_pos, &w) in scores.iter().enumerate() {
for dim in 0..d {
expected[q_pos * d + dim] += w * v[k_pos * d + dim];
}
}
}
assert_close_f32(&out, &expected, 1e-5, "single-head MQA vs standard SDPA");
}
#[test]
fn test_mqa_output_all_finite() {
let h = 8;
let d = 16;
let batch = 1;
let seq = 12;
let config = MqaConfig::new(h, d);
let q = make_tensor(batch * seq * h * d, 1.7);
let k = make_tensor(batch * seq * d, 2.1);
let v = make_tensor(batch * seq * d, 0.9);
let out = multi_query_attention(&q, &k, &v, None, &config, batch, seq)
.expect("mqa finite");
for (i, &val) in out.iter().enumerate() {
assert!(val.is_finite(), "mqa out[{i}] = {val} not finite");
}
}
#[test]
fn test_mqa_error_on_wrong_query_shape() {
let config = MqaConfig::new(4, 8);
let q = vec![0.0_f32; 3];
let k = make_tensor(6 * 8, 1.0);
let v = make_tensor(6 * 8, 1.0);
let result = multi_query_attention(&q, &k, &v, None, &config, 1, 6);
assert!(result.is_err(), "should error on wrong query shape");
}
#[test]
fn test_mqa_mask_reduces_scores() {
let h = 1;
let d = 4;
let batch = 1;
let seq = 4;
let config = MqaConfig::new(h, d);
let q = vec![1.0_f32; batch * seq * h * d];
let k = vec![1.0_f32; batch * seq * d];
let mut v = vec![0.0_f32; batch * seq * d];
for dim in 0..d {
v[(seq - 1) * d + dim] = 1.0;
}
let mut mask = vec![f32::NEG_INFINITY; batch * seq * seq];
for q_pos in 0..seq {
// Leave only the last key position unmasked.
mask[q_pos * seq + (seq - 1)] = 0.0;
}
let out = multi_query_attention(&q, &k, &v, Some(&mask), &config, batch, seq)
.expect("mqa mask");
for (i, &val) in out.iter().enumerate() {
assert!(
(val - 1.0).abs() < 1e-4,
"mqa masked out[{i}] expected 1.0, got {val}"
);
}
}
#[test]
fn test_repeat_kv_output_shape() {
let batch = 2;
let seq = 6;
let nkv = 2;
let d = 4;
let n_rep = 3;
let kv = make_tensor(batch * seq * nkv * d, 1.0);
let out = repeat_kv(&kv, n_rep, batch, seq, nkv, d);
assert_eq!(out.len(), batch * seq * nkv * n_rep * d);
}
#[test]
fn test_repeat_kv_n_rep_one_is_identity() {
let batch = 1;
let seq = 4;
let nkv = 3;
let d = 4;
let kv = make_tensor(batch * seq * nkv * d, 2.5);
let out = repeat_kv(&kv, 1, batch, seq, nkv, d);
assert_eq!(out, kv, "n_rep=1 should be identity");
}
#[test]
fn test_repeat_kv_values_are_correct() {
let batch = 1;
let seq = 2;
let nkv = 2;
let d = 2;
let n_rep = 2;
let kv = vec![10.0_f32, 11.0, 20.0, 21.0, 30.0, 31.0, 40.0, 41.0];
let out = repeat_kv(&kv, n_rep, batch, seq, nkv, d);
assert_eq!(out.len(), batch * seq * nkv * n_rep * d);
assert_eq!(&out[0..2], &[10.0, 11.0], "s=0 kv_h=0 rep=0");
assert_eq!(&out[2..4], &[10.0, 11.0], "s=0 kv_h=0 rep=1");
assert_eq!(&out[4..6], &[20.0, 21.0], "s=0 kv_h=1 rep=0");
assert_eq!(&out[6..8], &[20.0, 21.0], "s=0 kv_h=1 rep=1");
}
#[test]
fn test_gqa_config_invalid_heads() {
let result = GqaConfig::new(7, 3, 8);
assert!(result.is_err(), "7 q heads / 3 kv heads should be invalid");
}
#[test]
fn test_gqa_config_valid() {
let config = GqaConfig::new(8, 2, 4).expect("valid gqa config");
assert_eq!(config.queries_per_kv(), 4);
}
#[test]
fn test_gqa_output_shape() {
let nq = 8;
let nkv = 2;
let d = 4;
let batch = 1;
let seq = 6;
let config = GqaConfig::new(nq, nkv, d).expect("gqa config");
let q = make_tensor(batch * seq * nq * d, 1.0);
let k = make_tensor(batch * seq * nkv * d, 2.0);
let v = make_tensor(batch * seq * nkv * d, 3.0);
let out = grouped_query_attention(&q, &k, &v, None, &config, batch, seq)
.expect("gqa shape");
assert_eq!(out.len(), batch * seq * nq * d);
}
#[test]
fn test_gqa_all_outputs_finite() {
let nq = 4;
let nkv = 2;
let d = 8;
let batch = 2;
let seq = 10;
let config = GqaConfig::new(nq, nkv, d).expect("gqa config");
let q = make_tensor(batch * seq * nq * d, 1.3);
let k = make_tensor(batch * seq * nkv * d, 0.7);
let v = make_tensor(batch * seq * nkv * d, 1.9);
let out = grouped_query_attention(&q, &k, &v, None, &config, batch, seq)
.expect("gqa finite");
for (i, &val) in out.iter().enumerate() {
assert!(val.is_finite(), "gqa out[{i}] = {val} not finite");
}
}
#[test]
fn test_gqa_mqa_equivalence() {
let nq = 4;
let d = 4;
let batch = 1;
let seq = 5;
let gqa_config = GqaConfig::new(nq, 1, d).expect("gqa config");
let mqa_config = MqaConfig { num_heads: nq, head_dim: d, scale: gqa_config.scale };
let q = make_tensor(batch * seq * nq * d, 1.0);
let k = make_tensor(batch * seq * d, 2.0);
let v = make_tensor(batch * seq * d, 3.0);
let mqa_out = multi_query_attention(&q, &k, &v, None, &mqa_config, batch, seq)
.expect("mqa");
let gqa_out = grouped_query_attention(&q, &k, &v, None, &gqa_config, batch, seq)
.expect("gqa");
assert_close_f32(&mqa_out, &gqa_out, 1e-5, "GQA(nkv=1) should equal MQA");
}
#[test]
fn test_gqa_mha_equivalence() {
let nq = 4;
let d = 4;
let batch = 1;
let seq = 5;
let gqa_config = GqaConfig::new(nq, nq, d).expect("gqa mha config");
let q = make_tensor(batch * seq * nq * d, 1.0);
let k = make_tensor(batch * seq * nq * d, 2.0);
let v = make_tensor(batch * seq * nq * d, 3.0);
let out = grouped_query_attention(&q, &k, &v, None, &gqa_config, batch, seq)
.expect("gqa mha");
assert_eq!(out.len(), batch * seq * nq * d);
for &val in &out {
assert!(val.is_finite());
}
}
#[test]
fn test_cross_attention_output_shape() {
let h = 4;
let d = 8;
let decoder_seq = 6;
let encoder_seq = 10;
let hidden = h * d;
let config = CrossAttentionConfig::new(h, d, hidden, hidden)
.expect("cross attn config");
let q = make_tensor(decoder_seq * hidden, 1.0);
let kv = make_tensor(encoder_seq * hidden, 2.0);
let out = cross_attention(&q, &kv, None, &config, decoder_seq, encoder_seq)
.expect("cross attn shape");
assert_eq!(out.len(), decoder_seq * hidden);
}
#[test]
fn test_cross_attention_all_finite() {
let h = 2;
let d = 4;
let decoder_seq = 8;
let encoder_seq = 12;
let hidden = h * d;
let config = CrossAttentionConfig::new(h, d, hidden, hidden)
.expect("cross attn config");
let q = make_tensor(decoder_seq * hidden, 1.5);
let kv = make_tensor(encoder_seq * hidden, 2.5);
let out = cross_attention(&q, &kv, None, &config, decoder_seq, encoder_seq)
.expect("cross attn finite");
for (i, &val) in out.iter().enumerate() {
assert!(val.is_finite(), "cross attn out[{i}] = {val} not finite");
}
}
#[test]
fn test_cross_attention_encoder_mask() {
let h = 1;
let d = 4;
let decoder_seq = 3;
let encoder_seq = 4;
let hidden = h * d;
let config = CrossAttentionConfig::new(h, d, hidden, hidden)
.expect("cross config");
let q = vec![1.0_f32; decoder_seq * hidden];
// Key/value tensor: zero everywhere except encoder position 2, which holds 2.0.
let mut kv = vec![0.0_f32; encoder_seq * hidden];
for dim in 0..d {
kv[2 * hidden + dim] = 2.0;
}
// Additive mask that hides every encoder position except position 2.
let mut mask = vec![f32::NEG_INFINITY; decoder_seq * encoder_seq];
for q_pos in 0..decoder_seq {
mask[q_pos * encoder_seq + 2] = 0.0;
}
let out = cross_attention(&q, &kv, Some(&mask), &config, decoder_seq, encoder_seq)
.expect("cross mask");
for (i, &val) in out.iter().enumerate() {
assert!(
(val - 2.0).abs() < 1e-4,
"cross attn masked out[{i}] expected 2.0, got {val}"
);
}
}
#[test]
fn test_cross_attention_config_validation() {
let result = CrossAttentionConfig::new(4, 8, 24, 32);
assert!(result.is_err(), "mismatched encoder_hidden_size should fail");
}
#[test]
fn test_cross_attention_decoder_longer_than_encoder() {
let h = 2;
let d = 4;
let decoder_seq = 20;
let encoder_seq = 5;
let hidden = h * d;
let config = CrossAttentionConfig::new(h, d, hidden, hidden)
.expect("cross config");
let q = make_tensor(decoder_seq * hidden, 1.0);
let kv = make_tensor(encoder_seq * hidden, 2.0);
let out = cross_attention(&q, &kv, None, &config, decoder_seq, encoder_seq)
.expect("cross decoder longer");
assert_eq!(out.len(), decoder_seq * hidden);
}
#[test]
fn test_cross_attention_error_wrong_kv_shape() {
let h = 2;
let d = 4;
let hidden = h * d;
let config = CrossAttentionConfig::new(h, d, hidden, hidden)
.expect("cross config");
let q = make_tensor(5 * hidden, 1.0);
let kv = vec![0.0_f32; 3];
let result = cross_attention(&q, &kv, None, &config, 5, 8);
assert!(result.is_err(), "should error on wrong kv shape");
}
#[test]
fn test_cross_attention_equals_self_attention_when_same_input() {
let h = 2;
let d = 4;
let seq = 6;
let hidden = h * d;
let config = CrossAttentionConfig::new(h, d, hidden, hidden)
.expect("cross config");
let qkv = make_tensor(seq * hidden, 1.3);
let cross_out = cross_attention(&qkv, &qkv, None, &config, seq, seq)
.expect("cross self");
let scale = config.scale;
let mut ref_out = vec![0.0_f32; seq * hidden];
for head in 0..h {
for q_pos in 0..seq {
let q_off = q_pos * hidden + head * d;
let q_vec = &qkv[q_off..q_off + d];
let mut scores: Vec<f32> = (0..seq)
.map(|k_pos| {
let k_off = k_pos * hidden + head * d;
dot_f32(q_vec, &qkv[k_off..k_off + d]) * scale
})
.collect();
softmax_inplace(&mut scores);
for (k_pos, &w) in scores.iter().enumerate() {
let v_off = k_pos * hidden + head * d;
for dim in 0..d {
ref_out[q_off + dim] += w * qkv[v_off + dim];
}
}
}
}
assert_close_f32(&cross_out, &ref_out, 1e-5, "cross self-attention equivalence");
}
}