use proptest::prelude::*;
/// Returns the KV-cache size, in bytes, of one token in one layer.
///
/// The result is `2 * n_kv * d_k * bytes_per_element`; each operand is
/// widened to `u64` before multiplying. The leading factor of 2 presumably
/// covers the separate key and value entries.
fn kv_bytes_per_token_layer(n_kv: u32, d_k: u32, bytes_per_element: u32) -> u64 {
    let heads = u64::from(n_kv);
    let head_dim = u64::from(d_k);
    let elem_size = u64::from(bytes_per_element);
    2 * heads * head_dim * elem_size
}
/// Returns the total KV-cache size in bytes for `layers` layers and
/// `seq_len` tokens.
///
/// This is `layers * seq_len` multiplied by the per-token, per-layer size
/// from `kv_bytes_per_token_layer(n_kv, d_k, bpe)`, all in `u64`.
fn kv_total_bytes(layers: u32, seq_len: u32, n_kv: u32, d_k: u32, bpe: u32) -> u64 {
    let per_token_layer = kv_bytes_per_token_layer(n_kv, d_k, bpe);
    u64::from(layers) * u64::from(seq_len) * per_token_layer
}
/// Counts how many entries of `layer_types` are `true`.
///
/// Going by the function name, a `true` entry marks an attention layer.
fn count_attention_layers(layer_types: &[bool]) -> usize {
    // Each `true` contributes 1 and each `false` contributes 0.
    layer_types.iter().map(|&flag| usize::from(flag)).sum()
}
proptest! {
// Property: the per-token, per-layer KV size equals the reference formula
// `2 * n_kv * d_k * bpe` and is strictly positive for the sampled inputs.
#[test]
fn prop_per_token_kv_bytes(
    n_kv in 1u32..32,
    d_k in 16u32..256,
    bpe in prop::sample::select(vec![2u32, 4])
) {
    let actual = kv_bytes_per_token_layer(n_kv, d_k, bpe);
    let reference = 2 * u64::from(n_kv) * u64::from(d_k) * u64::from(bpe);
    prop_assert_eq!(
        actual, reference,
        "KV bytes: {} != expected {} for n_kv={}, d_k={}, bpe={}",
        actual, reference, n_kv, d_k, bpe
    );
    // All sampled factors are at least 1, so the product cannot be zero.
    prop_assert!(actual > 0, "KV bytes must be > 0");
}
// Property: with every other parameter fixed (2 bytes per element), a
// strictly longer sequence yields a strictly larger total KV size.
#[test]
fn prop_kv_total_monotonic_seq(
    layers in 1u32..64,
    seq1 in 1u32..4096,
    delta in 1u32..4096,
    n_kv in 1u32..16,
    d_k in 16u32..128
) {
    // `delta >= 1`, so the second length is strictly greater than `seq1`.
    let longer_seq = seq1.saturating_add(delta);
    let bytes_at = |seq_len: u32| kv_total_bytes(layers, seq_len, n_kv, d_k, 2);
    let short_total = bytes_at(seq1);
    let long_total = bytes_at(longer_seq);
    prop_assert!(
        short_total < long_total,
        "not monotonic: seq1={} -> {} bytes, seq2={} -> {} bytes",
        seq1, short_total, longer_seq, long_total
    );
}
// Property: the number of attention layers reported for a random layer mix
// never exceeds the number of layers in that mix.
#[test]
fn prop_hybrid_kv_bounded(
    layer_types in proptest::collection::vec(proptest::bool::ANY, 2..64usize)
) {
    let attn = count_attention_layers(&layer_types);
    let total = layer_types.len();
    prop_assert!(
        total >= attn,
        "attention layers {} > total {}", attn, total
    );
}
// Property: multiplying a row-major `rows x cols` all-ones matrix by an
// all-zero vector (no bias term) gives an all-zero output vector.
#[test]
fn prop_zero_input_identity(
    rows in 2usize..32,
    cols in 2usize..32
) {
    let input = vec![0.0f32; cols];
    let matrix = vec![1.0f32; rows * cols];
    // Each chunk of `cols` elements is one row; dot it with the input.
    let output: Vec<f32> = matrix
        .chunks_exact(cols)
        .map(|row| row.iter().zip(&input).map(|(w, x)| w * x).sum::<f32>())
        .collect();
    for (i, val) in output.into_iter().enumerate() {
        prop_assert!(
            val == 0.0,
            "bias-free W @ zeros != zeros at [{}]: {}", i, val
        );
    }
}
// Property: with every other parameter fixed (2 bytes per element), a
// strictly larger layer count yields a strictly larger total KV size.
#[test]
fn prop_kv_total_monotonic_layers(
    layers1 in 1u32..64,
    delta in 1u32..64,
    seq_len in 1u32..2048,
    n_kv in 1u32..16,
    d_k in 16u32..128
) {
    // `delta >= 1`, so the second layer count is strictly greater.
    let more_layers = layers1.saturating_add(delta);
    let [fewer_total, more_total] = [layers1, more_layers]
        .map(|layer_count| kv_total_bytes(layer_count, seq_len, n_kv, d_k, 2));
    prop_assert!(
        fewer_total < more_total,
        "not monotonic: L1={} -> {}, L2={} -> {}",
        layers1, fewer_total, more_layers, more_total
    );
}
// Placeholder property: the body is empty, so this test asserts nothing, and
// it is skipped by default because of `#[ignore]`. The ignore reason says SIMD
// equivalence belongs to the "trueno domain" — presumably it is verified in a
// separate trueno component; confirm before relying on this.
#[test]
#[ignore = "SIMD equivalence — trueno domain"]
fn prop_simd_equivalence(
_x in proptest::collection::vec(0u8..=255, 1..32usize)
) {
}
}