aprender-core 0.31.2

Next-generation machine learning library in pure Rust
// CONTRACT: gqa-kernel-v1.yaml
// HASH: sha256:d0e1f2a3b4c56789
// Generated by: pv probar --binding
// Obligation tests for GroupedQueryAttention standalone API.

use aprender::autograd::Tensor;
use aprender::nn::{GroupedQueryAttention, Module};
use proptest::prelude::*;

/// Strategy: valid (embed_dim, num_heads, num_kv_heads) triples.
///
/// Constraints from constructor:
///   embed_dim % num_heads == 0
///   num_heads % num_kv_heads == 0
///   embed_dim > 0, num_heads > 0, num_kv_heads > 0
///
/// The fixed table covers:
/// - head_dim (embed_dim / num_heads) in {4, 8};
/// - num_kv_heads in {1, 2, 4, 8};
/// - groups (num_heads / num_kv_heads) in {1, 2, 4, 8}.
///
/// This includes both extremes: MQA (a single KV head) and MHA
/// (num_kv_heads == num_heads). Every embed_dim stays at or below 64, so
/// each case is cheap to run.
///
/// # Panics
/// Panics if an entry in the table violates one of the constraints above.
/// A bad edit to the table is then reported here, not as a constructor
/// panic inside an unrelated property.
fn gqa_config_strategy() -> impl Strategy<Value = (usize, usize, usize)> {
    const CONFIGS: [(usize, usize, usize); 8] = [
        (16, 4, 1), // head_dim=4, 4 Q heads, 1 KV head (MQA), groups=4
        (16, 4, 2), // head_dim=4, 4 Q heads, 2 KV heads, groups=2
        (16, 4, 4), // head_dim=4, 4 Q heads, 4 KV heads (MHA), groups=1
        (32, 8, 1), // head_dim=4, 8 Q heads, 1 KV head (MQA), groups=8
        (32, 8, 2), // head_dim=4, 8 Q heads, 2 KV heads, groups=4
        (32, 8, 4), // head_dim=4, 8 Q heads, 4 KV heads, groups=2
        (32, 8, 8), // head_dim=4, 8 Q heads, 8 KV heads (MHA), groups=1
        (64, 8, 2), // head_dim=8, 8 Q heads, 2 KV heads, groups=4
    ];

    // The `> 0` checks come first so the `%` checks never divide by zero.
    for &(embed_dim, num_heads, num_kv_heads) in &CONFIGS {
        assert!(
            embed_dim > 0
                && num_heads > 0
                && num_kv_heads > 0
                && embed_dim % num_heads == 0
                && num_heads % num_kv_heads == 0,
            "invalid GQA config ({embed_dim}, {num_heads}, {num_kv_heads})"
        );
    }

    prop::sample::select(CONFIGS.to_vec())
}

proptest! {
    #![proptest_config(ProptestConfig::with_cases(20))]

    /// Obligation: Output shape correctness (invariant)
    /// Formal: shape(GQA(Q,K,V)) = shape(Q)
    ///
    /// Runs the `Module::forward` path on an all-ones
    /// `[batch, seq_len, embed_dim]` input. It checks that the output has
    /// exactly the input's shape.
    #[test]
    fn prop_output_shape(
        (embed_dim, num_heads, num_kv_heads) in gqa_config_strategy(),
        batch in 1_usize..=3,
        seq_len in 1_usize..=8,
    ) {
        let gqa = GroupedQueryAttention::new(embed_dim, num_heads, num_kv_heads);
        let x = Tensor::ones(&[batch, seq_len, embed_dim]);
        let output = gqa.forward(&x);

        // Output shape must equal input shape (identity dimensions)
        prop_assert_eq!(output.shape(), &[batch, seq_len, embed_dim]);
    }

    /// Obligation: KV head broadcast correctness (invariant)
    /// Formal: num_q_heads / num_kv_heads is integer, each KV head shared
    ///
    /// NOTE(review): this checks shapes only. It confirms that the returned
    /// attention weights carry one slice per *query* head even when fewer KV
    /// heads are configured. It does not check which KV head each query head
    /// actually reads.
    #[test]
    fn prop_kv_broadcast(
        (embed_dim, num_heads, num_kv_heads) in gqa_config_strategy(),
        batch in 1_usize..=2,
        tgt_len in 1_usize..=4,
        src_len in 1_usize..=4,
    ) {
        let gqa = GroupedQueryAttention::new(embed_dim, num_heads, num_kv_heads);

        // forward_qkv allows different Q vs K/V seq lengths
        let q = Tensor::ones(&[batch, tgt_len, embed_dim]);
        let k = Tensor::ones(&[batch, src_len, embed_dim]);
        let v = Tensor::ones(&[batch, src_len, embed_dim]);

        let (output, attn_weights) = gqa.forward_qkv(&q, &k, &v, None);

        // Output shape matches Q's batch and tgt_len
        prop_assert_eq!(output.shape(), &[batch, tgt_len, embed_dim]);
        // Attention weights: [batch, num_heads, tgt_len, src_len]
        // (KV heads expanded to match Q heads)
        prop_assert_eq!(attn_weights.shape(), &[batch, num_heads, tgt_len, src_len]);
    }

    /// Obligation: Attention weights sum to 1 per query (invariant)
    /// Formal: sum(attn_weights, dim=-1) ≈ 1 for each query head
    ///
    /// Each (batch, head, query) row must be a probability distribution over
    /// keys: every weight is non-negative and the row sums to 1 within 1e-4.
    /// The shape and buffer length are checked first, so a layout bug is
    /// reported as a property failure rather than an index panic.
    ///
    /// NOTE(review): the input is all ones. A stronger test would use a
    /// non-constant input, which makes the weights non-uniform.
    #[test]
    fn prop_attention_weights_sum_to_1(
        (embed_dim, num_heads, num_kv_heads) in gqa_config_strategy(),
        batch in 1_usize..=2,
        seq_len in 1_usize..=4,
    ) {
        let gqa = GroupedQueryAttention::new(embed_dim, num_heads, num_kv_heads);
        let x = Tensor::ones(&[batch, seq_len, embed_dim]);
        let (_output, attn_weights) = gqa.forward_self(&x, None);

        // attn_weights shape: [batch, num_heads, seq_len, seq_len]
        prop_assert_eq!(attn_weights.shape(), &[batch, num_heads, seq_len, seq_len]);
        let data = attn_weights.data();
        prop_assert_eq!(data.len(), batch * num_heads * seq_len * seq_len);

        // The buffer is row-major [batch, head, query, key]. Each run of
        // `seq_len` consecutive values is one query's weights over all keys.
        for (row, weights) in data.chunks_exact(seq_len).enumerate() {
            let bi = row / (num_heads * seq_len);
            let hi = (row / seq_len) % num_heads;
            let ti = row % seq_len;

            prop_assert!(
                weights.iter().all(|&w| w >= 0.0),
                "Negative or NaN attention weight in {weights:?} at batch={bi}, head={hi}, query={ti}"
            );

            let row_sum: f32 = weights.iter().sum();
            prop_assert!(
                (row_sum - 1.0).abs() < 1e-4,
                "Attention weights row sum = {row_sum}, expected ≈ 1.0 at batch={bi}, head={hi}, query={ti}"
            );
        }
    }

    /// Obligation: Output finite for finite inputs (invariant)
    ///
    /// Both the attention output and the attention weights must be free of
    /// NaN and infinity.
    #[test]
    fn prop_finite_output(
        (embed_dim, num_heads, num_kv_heads) in gqa_config_strategy(),
        batch in 1_usize..=2,
        seq_len in 1_usize..=4,
    ) {
        let gqa = GroupedQueryAttention::new(embed_dim, num_heads, num_kv_heads);
        let x = Tensor::ones(&[batch, seq_len, embed_dim]);
        let (output, attn_weights) = gqa.forward_self(&x, None);

        // All output values must be finite
        for &v in output.data() {
            prop_assert!(v.is_finite(), "Non-finite output value: {v}");
        }
        for &v in attn_weights.data() {
            prop_assert!(v.is_finite(), "Non-finite attention weight: {v}");
        }
    }

    /// Obligation: GQA with num_kv_heads == num_heads equivalent to MHA (equivalence)
    ///
    /// When groups=1 (no KV sharing), GQA degenerates to standard MHA.
    /// This covers every (embed_dim, num_heads) pair from
    /// `gqa_config_strategy`, with the KV head count forced equal to the
    /// Q head count. It checks:
    /// - the output keeps the input shape;
    /// - the attention weights are `[batch, num_heads, seq_len, seq_len]`;
    /// - all output and weight values are finite.
    ///
    /// NOTE(review): no reference MHA output is computed here, so numerical
    /// equivalence with an MHA implementation is not checked. Only the
    /// structural side is covered.
    #[test]
    fn prop_gqa_mha_equivalence(
        (embed_dim, num_heads) in gqa_config_strategy().prop_map(|(e, h, _)| (e, h)),
        batch in 1_usize..=2,
        seq_len in 1_usize..=4,
    ) {
        // GQA with kv_heads == q_heads is exactly MHA. num_heads % num_heads == 0,
        // so every pair from the strategy is a valid configuration.
        let gqa = GroupedQueryAttention::new(embed_dim, num_heads, num_heads);

        let x = Tensor::ones(&[batch, seq_len, embed_dim]);
        let (output, attn_weights) = gqa.forward_self(&x, None);

        // Shape correctness
        prop_assert_eq!(output.shape(), &[batch, seq_len, embed_dim]);
        prop_assert_eq!(attn_weights.shape(), &[batch, num_heads, seq_len, seq_len]);

        // All finite
        for &v in output.data() {
            prop_assert!(v.is_finite(), "Non-finite MHA-equivalent output: {v}");
        }
        for &v in attn_weights.data() {
            prop_assert!(v.is_finite(), "Non-finite MHA-equivalent attention weight: {v}");
        }
    }
}