// CONTRACT: gqa-kernel-v1.yaml
// HASH: sha256:d0e1f2a3b4c56789
// Generated by: pv probar --binding
// Obligation tests for GroupedQueryAttention standalone API.
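//
// Background (descriptive note, consistent with the per-config comments
// below): grouped-query attention (GQA) projects num_heads query heads but
// only num_kv_heads key/value heads; each KV head is broadcast to a group
// of num_heads / num_kv_heads query heads. num_kv_heads == num_heads
// recovers standard multi-head attention (MHA); num_kv_heads == 1 is
// multi-query attention (MQA).
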
use aprender::autograd::Tensor;
use aprender::nn::{GroupedQueryAttention, Module};
use proptest::prelude::*;

/// Strategy: valid `(embed_dim, num_heads, num_kv_heads)` triples.
///
/// Constraints from the constructor:
/// - `embed_dim % num_heads == 0`
/// - `num_heads % num_kv_heads == 0`
/// - `embed_dim > 0`, `num_heads > 0`, `num_kv_heads > 0`
fn gqa_config_strategy() -> impl Strategy<Value = (usize, usize, usize)> {
    // head_dim in {4, 8}, num_kv_heads in {1, 2, 4, 8}, groups in {1, 2, 4, 8}.
    // This keeps embed_dim ≤ 64, so every case stays cheap to run.
    prop::sample::select(vec![
        (16, 4, 1), // embed=16, 4 Q heads, 1 KV head (MQA), groups=4
        (16, 4, 2), // embed=16, 4 Q heads, 2 KV heads, groups=2
        (16, 4, 4), // embed=16, 4 Q heads, 4 KV heads (MHA), groups=1
        (32, 8, 1), // embed=32, 8 Q heads, 1 KV head (MQA), groups=8
        (32, 8, 2), // embed=32, 8 Q heads, 2 KV heads, groups=4
        (32, 8, 4), // embed=32, 8 Q heads, 4 KV heads, groups=2
        (32, 8, 8), // embed=32, 8 Q heads, 8 KV heads (MHA), groups=1
        (64, 8, 2), // embed=64, 8 Q heads, 2 KV heads, head_dim=8
    ])
}
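
// Worked example of the constraints above, for the (32, 8, 2) triple:
//   head_dim = embed_dim / num_heads    = 32 / 8 = 4
//   groups   = num_heads / num_kv_heads = 8 / 2  = 4
// i.e. each of the 2 KV heads is shared by 4 query heads.
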
proptest! {
    #![proptest_config(ProptestConfig::with_cases(20))]

    /// Obligation: Output shape correctness (invariant)
    /// Formal: shape(GQA(Q, K, V)) = shape(Q)
    #[test]
    fn prop_output_shape(
        (embed_dim, num_heads, num_kv_heads) in gqa_config_strategy(),
        batch in 1_usize..=3,
        seq_len in 1_usize..=8,
    ) {
        let gqa = GroupedQueryAttention::new(embed_dim, num_heads, num_kv_heads);
        let x = Tensor::ones(&[batch, seq_len, embed_dim]);
        let output = gqa.forward(&x);
        // Self-attention must preserve the input shape.
        prop_assert_eq!(output.shape(), &[batch, seq_len, embed_dim]);
    }

    /// Obligation: KV head broadcast correctness (invariant)
    /// Formal: num_heads / num_kv_heads is an integer; each KV head is
    /// shared by that many query heads.
    #[test]
    fn prop_kv_broadcast(
        (embed_dim, num_heads, num_kv_heads) in gqa_config_strategy(),
        batch in 1_usize..=2,
        tgt_len in 1_usize..=4,
        src_len in 1_usize..=4,
    ) {
        let gqa = GroupedQueryAttention::new(embed_dim, num_heads, num_kv_heads);
        // forward_qkv allows different Q vs. K/V sequence lengths.
        let q = Tensor::ones(&[batch, tgt_len, embed_dim]);
        let k = Tensor::ones(&[batch, src_len, embed_dim]);
        let v = Tensor::ones(&[batch, src_len, embed_dim]);
        let (output, attn_weights) = gqa.forward_qkv(&q, &k, &v, None);
        // Output shape follows Q's batch and tgt_len.
        prop_assert_eq!(output.shape(), &[batch, tgt_len, embed_dim]);
        // Attention weights: [batch, num_heads, tgt_len, src_len]
        // (KV heads expanded to match Q heads).
        prop_assert_eq!(attn_weights.shape(), &[batch, num_heads, tgt_len, src_len]);
    }

    /// Obligation: Attention weights sum to 1 per query (invariant)
    /// Formal: sum(attn_weights, dim=-1) ≈ 1 for each query head
    #[test]
    fn prop_attention_weights_sum_to_1(
        (embed_dim, num_heads, num_kv_heads) in gqa_config_strategy(),
    ) {
        let gqa = GroupedQueryAttention::new(embed_dim, num_heads, num_kv_heads);
        let batch = 1_usize;
        let seq_len = 4_usize;
        let x = Tensor::ones(&[batch, seq_len, embed_dim]);
        let (_output, attn_weights) = gqa.forward_self(&x, None);
        // attn_weights shape: [batch, num_heads, seq_len, seq_len]
        let data = attn_weights.data();
        let shape = attn_weights.shape();
        let (b, h, tgt, src) = (shape[0], shape[1], shape[2], shape[3]);
        // For each (batch, head, query_position), the weights over keys
        // should sum to ~1.0. The buffer is row-major, so each row of
        // `src` weights starts at the offset computed below.
        for bi in 0..b {
            for hi in 0..h {
                for ti in 0..tgt {
                    let offset = bi * h * tgt * src + hi * tgt * src + ti * src;
                    let row_sum: f32 = data[offset..offset + src].iter().sum();
                    prop_assert!(
                        (row_sum - 1.0).abs() < 1e-4,
                        "Attention weights row sum = {row_sum}, expected ≈ 1.0 at batch={bi}, head={hi}, query={ti}"
                    );
                }
            }
        }
    }

    /// Obligation: Output finite for finite inputs (invariant)
    #[test]
    fn prop_finite_output(
        (embed_dim, num_heads, num_kv_heads) in gqa_config_strategy(),
        batch in 1_usize..=2,
        seq_len in 1_usize..=4,
    ) {
        let gqa = GroupedQueryAttention::new(embed_dim, num_heads, num_kv_heads);
        let x = Tensor::ones(&[batch, seq_len, embed_dim]);
        let (output, attn_weights) = gqa.forward_self(&x, None);
        // All output values must be finite.
        for &v in output.data() {
            prop_assert!(v.is_finite(), "Non-finite output value: {v}");
        }
        for &v in attn_weights.data() {
            prop_assert!(v.is_finite(), "Non-finite attention weight: {v}");
        }
    }

    /// Obligation: GQA with num_kv_heads == num_heads is MHA (equivalence)
    ///
    /// When groups == 1 (no KV sharing), GQA degenerates to standard
    /// multi-head attention; this checks that the degenerate configuration
    /// produces the expected output shape and attention-weight structure.
    #[test]
    fn prop_gqa_mha_equivalence(
        batch in 1_usize..=2,
        seq_len in 1_usize..=4,
    ) {
        // GQA with num_kv_heads == num_heads is exactly MHA.
        let embed_dim = 32;
        let num_heads = 8;
        let gqa = GroupedQueryAttention::new(embed_dim, num_heads, num_heads);
        let x = Tensor::ones(&[batch, seq_len, embed_dim]);
        let (output, attn_weights) = gqa.forward_self(&x, None);
        // Shape correctness.
        prop_assert_eq!(output.shape(), &[batch, seq_len, embed_dim]);
        prop_assert_eq!(attn_weights.shape(), &[batch, num_heads, seq_len, seq_len]);
        // All finite.
        for &v in output.data() {
            prop_assert!(v.is_finite(), "Non-finite MHA-equivalent output: {v}");
        }
    }
}
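
// Supplementary deterministic check (a sketch, not generated from the
// contract): with a single key position, each softmax row has exactly one
// logit, so by the sum-to-1 invariant above every attention weight must be
// 1.0. It reuses only the API exercised by the obligations above
// (GroupedQueryAttention::new, forward_self, Tensor::ones).
#[test]
fn single_key_attention_weight_is_one() {
    let gqa = GroupedQueryAttention::new(16, 4, 2);
    let x = Tensor::ones(&[1, 1, 16]);
    let (_output, attn_weights) = gqa.forward_self(&x, None);
    // Shape: [batch=1, num_heads=4, tgt_len=1, src_len=1].
    assert_eq!(attn_weights.shape(), &[1, 4, 1, 1]);
    for &w in attn_weights.data() {
        assert!(
            (w - 1.0).abs() < 1e-6,
            "softmax over a single key must yield weight 1.0, got {w}"
        );
    }
}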