use super::*;
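
// GQ-001: every attention-weight row produced by forward_qkv must sum to 1
// (softmax normalization), here with cross-attention shapes (q_len = 5, kv_len = 8).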
#[test]
fn falsify_gq_001_weight_normalization() {
    let gqa = GroupedQueryAttention::new(32, 4, 2);
    let q = Tensor::new(
        &(0..2 * 5 * 32)
            .map(|i| (i as f32 * 0.01).sin())
            .collect::<Vec<_>>(),
        &[2, 5, 32],
    );
    let k = Tensor::new(
        &(0..2 * 8 * 32)
            .map(|i| (i as f32 * 0.02).cos())
            .collect::<Vec<_>>(),
        &[2, 8, 32],
    );
    let v = Tensor::ones(&[2, 8, 32]);
    let (_, attn_weights) = gqa.forward_qkv(&q, &k, &v, None);
    let w_data = attn_weights.data();
    let w_shape = attn_weights.shape();
    let batch = w_shape[0];
    let heads = w_shape[1];
    let q_len = w_shape[2];
    let kv_len = w_shape[3];
    for b in 0..batch {
        for h in 0..heads {
            for q_pos in 0..q_len {
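                // Flat index of the first element of row [b, h, q_pos, :]
                // in the row-major attention-weight buffer.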
                let row_start =
                    b * heads * q_len * kv_len + h * q_len * kv_len + q_pos * kv_len;
                let row_sum: f32 = (0..kv_len).map(|j| w_data[row_start + j]).sum();
                assert!(
                    (row_sum - 1.0).abs() < 1e-4,
                    "FALSIFIED GQ-001: attn_weights[{b}][{h}][{q_pos}] sums to {row_sum}, expected 1.0"
                );
            }
        }
    }
}
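
// GQ-004: the constructor must reject configurations where num_heads is not
// divisible by num_kv_heads (8 query heads cannot be shared evenly across 3 KV heads).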
#[test]
#[should_panic(expected = "must be divisible")]
fn falsify_gq_004_head_divisibility() {
    let _gqa = GroupedQueryAttention::new(64, 8, 3);
}
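
// GQ-006: the multi-query boundary (num_kv_heads = 1) must preserve the input
// shape and produce only finite output values.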
#[test]
fn falsify_gq_006_mqa_boundary() {
    let mqa = GroupedQueryAttention::new(32, 4, 1);
    assert_eq!(mqa.num_kv_heads(), 1);
    let x = Tensor::ones(&[1, 6, 32]);
    let output = mqa.forward(&x);
    assert_eq!(output.shape(), &[1, 6, 32]);
    for (i, &val) in output.data().iter().enumerate() {
        assert!(
            val.is_finite(),
            "FALSIFIED GQ-006: MQA output[{i}] = {val} (not finite)"
        );
    }
}
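
// GQ-002: with num_kv_heads == num_heads, GQA degenerates to standard MHA.
// Only shapes and finiteness are compared here; numerical equality between the
// two modules is not asserted, since each carries its own weights.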
#[test]
fn falsify_gq_002_mha_degeneration() {
    let gqa_as_mha = GroupedQueryAttention::new(32, 4, 4);
    let mha = MultiHeadAttention::new(32, 4);
    assert_eq!(gqa_as_mha.num_kv_heads(), gqa_as_mha.num_heads());
    let x = Tensor::ones(&[1, 5, 32]);
    let gqa_out = gqa_as_mha.forward(&x);
    let (mha_out, _) = mha.forward_self(&x, None);
    assert_eq!(gqa_out.shape(), mha_out.shape());
    for (i, &val) in gqa_out.data().iter().enumerate() {
        assert!(
            val.is_finite(),
            "FALSIFIED GQ-002: GQA-as-MHA output[{i}] = {val} (not finite)"
        );
    }
}
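
// GQ-003: each head's pre-projection output is a softmax-weighted (convex)
// combination of value rows; this test checks the weaker property that the
// final output stays finite when the value entries are scaled by 5.0.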
#[test]
fn falsify_gq_003_output_convexity() {
    let gqa = GroupedQueryAttention::new(16, 4, 2);
    let q_data: Vec<f32> = (0..1 * 3 * 16).map(|i| (i as f32 * 0.37).sin()).collect();
    let k_data: Vec<f32> = (0..1 * 3 * 16).map(|i| (i as f32 * 0.73).cos()).collect();
    let v_data: Vec<f32> = (0..1 * 3 * 16)
        .map(|i| (i as f32 * 1.23).sin() * 5.0)
        .collect();
    let q = Tensor::new(&q_data, &[1, 3, 16]);
    let k = Tensor::new(&k_data, &[1, 3, 16]);
    let v = Tensor::new(&v_data, &[1, 3, 16]);
    let (output, _) = gqa.forward_qkv(&q, &k, &v, None);
    let out = output.data();
    for (i, &val) in out.iter().enumerate() {
        assert!(
            val.is_finite(),
            "FALSIFIED GQ-003: output[{i}] = {val} (not finite)"
        );
    }
}
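
// Property-based variants of the falsification tests above, driven by proptest
// over seeded inputs and a sample of GQA configurations.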
mod gq_proptest_falsify {
    use super::*;
    use proptest::prelude::*;
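
    // GQ-001 (property): attention-weight rows must sum to 1 across a sample of
    // (embed_dim, num_heads, num_kv_heads) configurations and seeded inputs.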
    proptest! {
        #![proptest_config(ProptestConfig::with_cases(50))]
        #[test]
        fn falsify_gq_001_prop_weight_normalization(
            config_idx in 0..8usize,
            seed in 0..500u32,
        ) {
            let configs = [
                (16, 4, 1), (16, 4, 2), (16, 4, 4),
                (32, 8, 1), (32, 8, 2), (32, 8, 4), (32, 8, 8),
                (64, 8, 2),
            ];
            let (embed_dim, num_heads, num_kv_heads) = configs[config_idx];
            let gqa = GroupedQueryAttention::new(embed_dim, num_heads, num_kv_heads);
            let seq_len = 4;
            let data: Vec<f32> = (0..1 * seq_len * embed_dim)
                .map(|i| ((i as f32 + seed as f32) * 0.37).sin())
                .collect();
            let x = Tensor::new(&data, &[1, seq_len, embed_dim]);
            let (_, attn_weights) = gqa.forward_self(&x, None);
            let w = attn_weights.data();
            let shape = attn_weights.shape();
            let (b, h, tgt, src) = (shape[0], shape[1], shape[2], shape[3]);
            for bi in 0..b {
                for hi in 0..h {
                    for ti in 0..tgt {
                        let offset = bi * h * tgt * src + hi * tgt * src + ti * src;
                        let row_sum: f32 = w[offset..offset + src].iter().sum();
                        prop_assert!(
                            (row_sum - 1.0).abs() < 1e-4,
                            "FALSIFIED GQ-001-prop: row sum = {} at b={},h={},q={}",
                            row_sum, bi, hi, ti
                        );
                    }
                }
            }
        }
    }
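
    // GQ-002 (property): with num_kv_heads == num_heads, GQA should match MHA's
    // output and weight shapes and stay finite across seeds and sequence lengths.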
    proptest! {
        #![proptest_config(ProptestConfig::with_cases(30))]
        #[test]
        fn falsify_gq_002_prop_mha_degeneration(
            seed in 0..500u32,
            seq_len in 2..=6usize,
        ) {
            let embed_dim = 32;
            let num_heads = 8;
            let gqa = GroupedQueryAttention::new(embed_dim, num_heads, num_heads);
            let data: Vec<f32> = (0..1 * seq_len * embed_dim)
                .map(|i| ((i as f32 + seed as f32) * 0.37).sin())
                .collect();
            let x = Tensor::new(&data, &[1, seq_len, embed_dim]);
            let (output, attn_weights) = gqa.forward_self(&x, None);
            prop_assert_eq!(output.shape(), &[1, seq_len, embed_dim]);
            prop_assert_eq!(attn_weights.shape(), &[1, num_heads, seq_len, seq_len]);
            for &v in output.data() {
                prop_assert!(v.is_finite(), "FALSIFIED GQ-002-prop: non-finite output");
            }
        }
    }
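
    // GQ-006 (property): the MQA boundary (num_kv_heads = 1) must keep the output
    // shape and normalized attention rows across seeds and sequence lengths.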
    proptest! {
        #![proptest_config(ProptestConfig::with_cases(30))]
        #[test]
        fn falsify_gq_006_prop_mqa_boundary(
            seed in 0..500u32,
            seq_len in 2..=6usize,
        ) {
            let embed_dim = 16;
            let num_heads = 4;
            let mqa = GroupedQueryAttention::new(embed_dim, num_heads, 1);
            let data: Vec<f32> = (0..1 * seq_len * embed_dim)
                .map(|i| ((i as f32 + seed as f32) * 0.73).cos())
                .collect();
            let x = Tensor::new(&data, &[1, seq_len, embed_dim]);
            let (output, attn_weights) = mqa.forward_self(&x, None);
            prop_assert_eq!(output.shape(), &[1, seq_len, embed_dim]);
            prop_assert_eq!(attn_weights.shape(), &[1, num_heads, seq_len, seq_len]);
            let w = attn_weights.data();
            let shape = attn_weights.shape();
            let (h, tgt, src) = (shape[1], shape[2], shape[3]);
            for hi in 0..h {
                for ti in 0..tgt {
                    let offset = hi * tgt * src + ti * src;
                    let row_sum: f32 = w[offset..offset + src].iter().sum();
                    prop_assert!(
                        (row_sum - 1.0).abs() < 1e-4,
                        "FALSIFIED GQ-006-prop: MQA row sum = {} at h={},q={}",
                        row_sum, hi, ti
                    );
                }
            }
        }
    }
}