/// Verdict for TSF-001: Q/K/V projection widths agree with the attention head layout.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Tsf001Verdict {
    /// Every projection width matches the head configuration.
    Pass,
    /// A head parameter is zero, a width disagrees, or an expected width overflows `u64`.
    Fail,
}

/// Checks that Q/K/V projection output widths are consistent with the head configuration.
///
/// Returns [`Tsf001Verdict::Pass`] only when `n_h`, `n_kv` and `d_k` are all non-zero and
///
/// * `q_dim == n_h * d_k`
/// * `k_dim == n_kv * d_k`
/// * `v_dim == n_kv * d_k` (which also forces `k_dim == v_dim`)
///
/// The expected widths are computed with `checked_mul`. If a product does not fit in `u64`, no
/// `u64` dimension can equal it, so the verdict is [`Tsf001Verdict::Fail`]. Plain `*` would
/// instead panic in debug builds, or wrap in release builds and possibly report a spurious
/// `Pass`.
#[must_use]
pub const fn verdict_from_qkv_shape(
    n_h: u64,
    n_kv: u64,
    d_k: u64,
    q_dim: u64,
    k_dim: u64,
    v_dim: u64,
) -> Tsf001Verdict {
    if n_h == 0 || n_kv == 0 || d_k == 0 {
        return Tsf001Verdict::Fail;
    }
    // `Option` equality is not usable in a `const fn`, so destructure with let-else instead.
    let (Some(q_want), Some(kv_want)) = (n_h.checked_mul(d_k), n_kv.checked_mul(d_k)) else {
        return Tsf001Verdict::Fail;
    };
    // K and V share the same expected width, so equality of k_dim and v_dim needs no extra check.
    if q_dim != q_want || k_dim != kv_want || v_dim != kv_want {
        return Tsf001Verdict::Fail;
    }
    Tsf001Verdict::Pass
}
/// Verdict for TSF-002: query heads group evenly over key/value heads.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Tsf002Verdict {
    /// `n_kv` is non-zero, at most `n_h`, and divides `n_h` exactly.
    Pass,
    /// A count is zero, `n_kv` exceeds `n_h`, or `n_h` is not a multiple of `n_kv`.
    Fail,
}

/// Checks that `n_h` query heads can be split into equal-sized groups over `n_kv` KV heads.
///
/// Both counts must be non-zero, `n_kv` may not exceed `n_h`, and `n_h` must be an exact
/// multiple of `n_kv`.
#[must_use]
pub const fn verdict_from_gqa_grouping(n_h: u64, n_kv: u64) -> Tsf002Verdict {
    // Short-circuiting keeps the divisibility test behind the zero checks.
    let groups_evenly = n_h != 0 && n_kv != 0 && n_kv <= n_h && n_h.is_multiple_of(n_kv);
    if groups_evenly {
        Tsf002Verdict::Pass
    } else {
        Tsf002Verdict::Fail
    }
}
/// Verdict for TSF-003: a residual connection preserves its input shape.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Tsf003Verdict {
    /// Both shapes are non-empty and identical.
    Pass,
    /// Either shape is empty, or the shapes differ in rank or in any extent.
    Fail,
}

/// Checks that a residual block's output shape is exactly its input shape.
///
/// Shapes are compared element-wise, including rank. An empty shape on either side is
/// rejected.
#[must_use]
pub fn verdict_from_residual_shape(input_shape: &[u64], output_shape: &[u64]) -> Tsf003Verdict {
    match (input_shape, output_shape) {
        ([], _) | (_, []) => Tsf003Verdict::Fail,
        (before, after) if before == after => Tsf003Verdict::Pass,
        _ => Tsf003Verdict::Fail,
    }
}
/// Verdict for TSF-004: SwiGLU gate/down projection widths are consistent.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Tsf004Verdict {
    /// `d_ff > h > 0` and both projections have the expected in/out widths.
    Pass,
    /// A width is zero, `d_ff` does not exceed `h`, or a projection width is wrong.
    Fail,
}

/// Checks the feed-forward projection widths against hidden size `h` and inner size `d_ff`.
///
/// Requires `h != 0` and `d_ff > h`. The gate projection must map `h -> d_ff`
/// (`gate_in_dim == h`, `gate_out_dim == d_ff`). The down projection must map `d_ff -> h`
/// (`down_in_dim == d_ff`, `down_out_dim == h`).
#[must_use]
pub const fn verdict_from_swiglu_shape(
    h: u64,
    d_ff: u64,
    gate_in_dim: u64,
    gate_out_dim: u64,
    down_in_dim: u64,
    down_out_dim: u64,
) -> Tsf004Verdict {
    // `d_ff > h >= 1` already rules out `d_ff == 0`, so it needs no separate zero test.
    let sizes_ok = h != 0 && d_ff > h;
    let gate_ok = gate_in_dim == h && gate_out_dim == d_ff;
    let down_ok = down_in_dim == d_ff && down_out_dim == h;
    if sizes_ok && gate_ok && down_ok {
        Tsf004Verdict::Pass
    } else {
        Tsf004Verdict::Fail
    }
}
/// Verdict for TSF-005: the LM head output width equals the vocabulary size.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Tsf005Verdict {
    /// `output_dim == vocab_size` and both are non-zero.
    Pass,
    /// Either value is zero, or the two differ.
    Fail,
}

/// Checks that the LM head produces exactly `vocab_size` outputs.
///
/// Zero on either side is rejected even if both are zero.
#[must_use]
pub const fn verdict_from_lm_head_shape(output_dim: u64, vocab_size: u64) -> Tsf005Verdict {
    match (output_dim, vocab_size) {
        (0, _) | (_, 0) => Tsf005Verdict::Fail,
        (out, vocab) if out == vocab => Tsf005Verdict::Pass,
        _ => Tsf005Verdict::Fail,
    }
}
/// Verdict for TSF-006: scalar and SIMD code paths produce the same output shape.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Tsf006Verdict {
    /// Both shapes are non-empty and identical.
    Pass,
    /// Either shape is empty, or the shapes differ in rank or in any extent.
    Fail,
}

/// Checks that the SIMD path's output shape is identical to the scalar path's.
///
/// Uses slice equality, which compares length and then each element. This matches the TSF-003
/// residual check. An empty shape on either side is rejected; if `scalar` is non-empty and
/// equal to `simd`, then `simd` is non-empty too, so one emptiness test suffices.
#[must_use]
pub fn verdict_from_simd_shape_parity(scalar: &[u64], simd: &[u64]) -> Tsf006Verdict {
    if !scalar.is_empty() && scalar == simd {
        Tsf006Verdict::Pass
    } else {
        Tsf006Verdict::Fail
    }
}
#[cfg(test)]
mod tests {
    //! Pinned verdicts for the TSF-001..TSF-006 shape checks. Each test feeds concrete
    //! dimensions to one verdict function and asserts the exact verdict returned.
    use super::*;

    // ---- TSF-001: Q/K/V projection widths ----

    #[test]
    fn tsf001_pass_qwen2_7b() {
        // 28 query heads, 4 KV heads, head dim 128 -> q = 3584, k = v = 512.
        let verdict = verdict_from_qkv_shape(28, 4, 128, 28 * 128, 4 * 128, 4 * 128);
        assert_eq!(verdict, Tsf001Verdict::Pass);
    }

    #[test]
    fn tsf001_pass_full_attention() {
        // n_kv == n_h, so every projection is 32 * 128 = 4096 wide.
        let verdict = verdict_from_qkv_shape(32, 32, 128, 32 * 128, 32 * 128, 32 * 128);
        assert_eq!(verdict, Tsf001Verdict::Pass);
    }

    #[test]
    fn tsf001_fail_wrong_q_dim() {
        // q_dim should be 28 * 128 = 3584.
        let verdict = verdict_from_qkv_shape(28, 4, 128, 1024, 512, 512);
        assert_eq!(verdict, Tsf001Verdict::Fail);
    }

    #[test]
    fn tsf001_fail_k_v_dim_mismatch() {
        // v_dim should be 4 * 128 = 512, like k_dim.
        let verdict = verdict_from_qkv_shape(28, 4, 128, 3584, 512, 1024);
        assert_eq!(verdict, Tsf001Verdict::Fail);
    }

    #[test]
    fn tsf001_fail_zero() {
        let verdict = verdict_from_qkv_shape(0, 4, 128, 0, 512, 512);
        assert_eq!(verdict, Tsf001Verdict::Fail);
    }

    // ---- TSF-002: GQA head grouping ----

    #[test]
    fn tsf002_pass_canonical() {
        let verdict = verdict_from_gqa_grouping(28, 4);
        assert_eq!(verdict, Tsf002Verdict::Pass);
    }

    #[test]
    fn tsf002_pass_full_mha() {
        let verdict = verdict_from_gqa_grouping(32, 32);
        assert_eq!(verdict, Tsf002Verdict::Pass);
    }

    #[test]
    fn tsf002_pass_mqa() {
        let verdict = verdict_from_gqa_grouping(32, 1);
        assert_eq!(verdict, Tsf002Verdict::Pass);
    }

    #[test]
    fn tsf002_fail_indivisible() {
        // 28 is not a multiple of 5.
        let verdict = verdict_from_gqa_grouping(28, 5);
        assert_eq!(verdict, Tsf002Verdict::Fail);
    }

    #[test]
    fn tsf002_fail_n_kv_above_n_h() {
        let verdict = verdict_from_gqa_grouping(4, 8);
        assert_eq!(verdict, Tsf002Verdict::Fail);
    }

    #[test]
    fn tsf002_fail_zero() {
        for (n_h, n_kv) in [(0_u64, 4_u64), (28, 0)] {
            assert_eq!(
                verdict_from_gqa_grouping(n_h, n_kv),
                Tsf002Verdict::Fail,
                "n_h={n_h}, n_kv={n_kv}"
            );
        }
    }

    // ---- TSF-003: residual shape preservation ----

    #[test]
    fn tsf003_pass_match() {
        let shape = [1_u64, 16, 4096];
        assert_eq!(verdict_from_residual_shape(&shape, &shape), Tsf003Verdict::Pass);
    }

    #[test]
    fn tsf003_fail_drift() {
        let before = [1_u64, 16, 4096];
        let after = [1_u64, 16, 4097];
        assert_eq!(verdict_from_residual_shape(&before, &after), Tsf003Verdict::Fail);
    }

    #[test]
    fn tsf003_fail_extra_dim() {
        let before = [1_u64, 16, 4096];
        let after = [1_u64, 16, 4096, 1];
        assert_eq!(verdict_from_residual_shape(&before, &after), Tsf003Verdict::Fail);
    }

    // ---- TSF-004: SwiGLU projection widths ----

    #[test]
    fn tsf004_pass_canonical() {
        let verdict = verdict_from_swiglu_shape(4096, 14336, 4096, 14336, 14336, 4096);
        assert_eq!(verdict, Tsf004Verdict::Pass);
    }

    #[test]
    fn tsf004_fail_d_ff_le_h() {
        let verdict = verdict_from_swiglu_shape(4096, 4096, 4096, 4096, 4096, 4096);
        assert_eq!(verdict, Tsf004Verdict::Fail);
    }

    #[test]
    fn tsf004_fail_gate_in_wrong() {
        let verdict = verdict_from_swiglu_shape(4096, 14336, 2048, 14336, 14336, 4096);
        assert_eq!(verdict, Tsf004Verdict::Fail);
    }

    #[test]
    fn tsf004_fail_down_out_wrong() {
        let verdict = verdict_from_swiglu_shape(4096, 14336, 4096, 14336, 14336, 8192);
        assert_eq!(verdict, Tsf004Verdict::Fail);
    }

    #[test]
    fn tsf004_fail_zero() {
        let verdict = verdict_from_swiglu_shape(0, 14336, 0, 14336, 14336, 4096);
        assert_eq!(verdict, Tsf004Verdict::Fail);
    }

    // ---- TSF-005: LM head width ----

    #[test]
    fn tsf005_pass_canonical() {
        let verdict = verdict_from_lm_head_shape(152064, 152064);
        assert_eq!(verdict, Tsf005Verdict::Pass);
    }

    #[test]
    fn tsf005_fail_drift() {
        let verdict = verdict_from_lm_head_shape(151936, 152064);
        assert_eq!(verdict, Tsf005Verdict::Fail);
    }

    #[test]
    fn tsf005_fail_zero() {
        for (output_dim, vocab_size) in [(0_u64, 152064_u64), (152064, 0)] {
            assert_eq!(
                verdict_from_lm_head_shape(output_dim, vocab_size),
                Tsf005Verdict::Fail,
                "output_dim={output_dim}, vocab_size={vocab_size}"
            );
        }
    }

    // ---- TSF-006: scalar/SIMD shape parity ----

    #[test]
    fn tsf006_pass_identical() {
        let shape = [1_u64, 16, 4096];
        assert_eq!(verdict_from_simd_shape_parity(&shape, &shape), Tsf006Verdict::Pass);
    }

    #[test]
    fn tsf006_fail_drift() {
        let scalar = [1_u64, 16, 4096];
        let simd = [1_u64, 16, 4097];
        assert_eq!(verdict_from_simd_shape_parity(&scalar, &simd), Tsf006Verdict::Fail);
    }

    #[test]
    fn tsf006_fail_length() {
        let scalar = [1_u64];
        let simd = [1_u64, 2];
        assert_eq!(verdict_from_simd_shape_parity(&scalar, &simd), Tsf006Verdict::Fail);
    }

    #[test]
    fn tsf006_fail_empty() {
        assert_eq!(verdict_from_simd_shape_parity(&[], &[]), Tsf006Verdict::Fail);
    }
}