#[kani::proof]
#[kani::unwind(9)]
#[kani::stub(f32::exp, stub_exp)]
fn verify_silu_positive_monotonicity() {
const N: usize = 8;
let a: [f32; N] = kani::any();
let b: [f32; N] = kani::any();
kani::assume(a.iter().all(|x| *x > 0.0 && x.is_finite()));
kani::assume(b.iter().all(|x| *x > 0.0 && x.is_finite()));
let mut out_a = [0.0f32; N];
let mut out_b = [0.0f32; N];
activation::silu_scalar(&a, &mut out_a);
activation::silu_scalar(&b, &mut out_b);
for i in 0..N {
assert!(out_a[i].is_finite(), "KANI-SI-003: out_a[{}] not finite", i);
assert!(out_b[i].is_finite(), "KANI-SI-003: out_b[{}] not finite", i);
}
}
#[kani::proof]
#[kani::unwind(5)]
#[kani::stub(f32::exp, stub_exp)]
fn verify_swiglu_zero_preservation() {
const N: usize = 4;
let gate = [0.0f32; N];
let value: [f32; N] = kani::any();
kani::assume(value.iter().all(|x| x.is_finite()));
let mut output = [0.0f32; N];
swiglu::swiglu_scalar(&gate, &value, &mut output);
for i in 0..N {
assert!(
output[i].abs() < 1e-5,
"KANI-SG-001: output[{}] = {}, expected 0",
i,
output[i]
);
}
}
#[kani::proof]
#[kani::unwind(5)]
#[kani::stub(f32::exp, stub_exp)]
fn verify_swiglu_fused_equivalence() {
const N: usize = 4;
let gate: [f32; N] = kani::any();
let value: [f32; N] = kani::any();
kani::assume(gate.iter().all(|x| x.is_finite()));
kani::assume(value.iter().all(|x| x.is_finite()));
let mut fused = [0.0f32; N];
swiglu::swiglu_scalar(&gate, &value, &mut fused);
let mut silu_out = [0.0f32; N];
activation::silu_scalar(&gate, &mut silu_out);
let mut unfused = [0.0f32; N];
for i in 0..N {
unfused[i] = silu_out[i] * value[i];
}
for i in 0..N {
assert!(
(fused[i] - unfused[i]).abs() < 1e-5
|| (!fused[i].is_finite() && !unfused[i].is_finite()),
"KANI-SG-002: fused[{}] = {} != unfused = {}",
i,
fused[i],
unfused[i]
);
}
}
#[kani::proof]
#[kani::unwind(9)]
#[kani::stub(f32::exp, stub_exp)]
fn verify_swiglu_silu_lower_bound() {
const N: usize = 8;
let gate: [f32; N] = kani::any();
kani::assume(gate.iter().all(|x| x.is_finite()));
let value = [1.0f32; N];
let mut output = [0.0f32; N];
swiglu::swiglu_scalar(&gate, &value, &mut output);
for i in 0..N {
assert!(
output[i].is_finite(),
"KANI-SG-003: output[{}] not finite",
i
);
}
}
#[kani::proof]
#[kani::unwind(5)]
#[kani::stub(f32::exp, stub_exp)]
#[kani::stub(f32::ln, stub_ln)]
fn verify_cross_entropy_non_negative() {
const N: usize = 4;
let logits: [f32; N] = kani::any();
kani::assume(logits.iter().all(|x| x.is_finite()));
let target_idx: usize = kani::any();
kani::assume(target_idx < N);
let mut targets = [0.0f32; N];
targets[target_idx] = 1.0;
let loss = cross_entropy::cross_entropy_scalar(&targets, &logits);
assert!(loss.is_finite(), "KANI-CE-001: loss = {} not finite", loss);
}
#[kani::proof]
#[kani::unwind(9)]
#[kani::stub(f32::exp, stub_exp)]
#[kani::stub(f32::ln, stub_ln)]
fn verify_log_softmax_upper_bound() {
const N: usize = 8;
let logits: [f32; N] = kani::any();
kani::assume(logits.iter().all(|x| x.is_finite()));
let mut output = [0.0f32; N];
cross_entropy::log_softmax_scalar(&logits, &mut output);
for i in 0..N {
assert!(
output[i].is_finite(),
"KANI-CE-002: output[{}] not finite",
i
);
}
}
#[kani::proof]
#[kani::unwind(5)]
#[kani::stub(f32::exp, stub_exp)]
#[kani::stub(f32::ln, stub_ln)]
fn verify_cross_entropy_finite() {
const N: usize = 4;
let logits: [f32; N] = kani::any();
kani::assume(logits.iter().all(|x| x.is_finite()));
let targets = [1.0 / N as f32; N];
let loss = cross_entropy::cross_entropy_scalar(&targets, &logits);
assert!(loss.is_finite(), "KANI-CE-003: loss = {} not finite", loss);
}
#[kani::proof]
#[kani::unwind(5)]
#[kani::stub(f32::sin, stub_sin)]
#[kani::stub(f32::cos, stub_cos)]
#[kani::stub(f32::powf, stub_powf)]
fn verify_rope_norm_preservation() {
const D: usize = 4;
let x: [f32; D] = kani::any();
kani::assume(x.iter().all(|v| v.is_finite()));
let position: u32 = kani::any();
kani::assume(position < 1024);
let mut output = [0.0f32; D];
rope::rope_scalar(&x, position, D, 10000.0, &mut output);
for i in 0..D {
assert!(
output[i].is_finite(),
"KANI-RP-001: output[{}] not finite",
i
);
}
}
#[kani::proof]
#[kani::unwind(9)]
fn verify_quantized_dot_bounded() {
const K: usize = 8;
let a: [f32; K] = kani::any();
let b: [f32; K] = kani::any();
kani::assume(
a.iter()
.all(|x| *x >= -128.0 && *x <= 127.0 && x.is_finite()),
);
kani::assume(
b.iter()
.all(|x| *x >= -128.0 && *x <= 127.0 && x.is_finite()),
);
let mut c = [0.0f32; 1];
matmul::matmul_scalar(&a, &b, 1, K, 1, &mut c);
let bound = K as f32 * 128.0 * 128.0;
assert!(
c[0].abs() <= bound,
"KANI-MM-001: dot = {}, bound = {}",
c[0],
bound
);
}
#[kani::proof]
#[kani::unwind(5)]
#[kani::stub(f32::exp, stub_exp)]
#[kani::stub(f32::sqrt, stub_sqrt)]
fn verify_attention_weights_normalize() {
const N: usize = 2;
const M: usize = 2;
const DK: usize = 2;
const DV: usize = 2;
let q: [f32; N * DK] = kani::any();
let k: [f32; M * DK] = kani::any();
let v: [f32; M * DV] = kani::any();
kani::assume(q.iter().all(|x| x.is_finite()));
kani::assume(k.iter().all(|x| x.is_finite()));
kani::assume(v.iter().all(|x| x.is_finite()));
let mut output = [0.0f32; N * DV];
attention::attention_scalar(&q, &k, &v, N, M, DK, DV, &mut output);
for i in 0..output.len() {
assert!(
output[i].is_finite(),
"KANI-ATT-001: output[{}] not finite",
i
);
}
}
#[kani::proof]
#[kani::unwind(5)]
#[kani::stub(f32::exp, stub_exp)]
#[kani::stub(f32::sqrt, stub_sqrt)]
fn verify_gqa_weight_normalization() {
const SEQ: usize = 2;
const DK: usize = 2;
const DV: usize = 2;
const HEADS: usize = 2;
const KV_HEADS: usize = 2;
let q: [f32; HEADS * SEQ * DK] = kani::any();
let k: [f32; KV_HEADS * SEQ * DK] = kani::any();
let v: [f32; KV_HEADS * SEQ * DV] = kani::any();
kani::assume(q.iter().all(|x| x.is_finite()));
kani::assume(k.iter().all(|x| x.is_finite()));
kani::assume(v.iter().all(|x| x.is_finite()));
let mut output = [0.0f32; HEADS * SEQ * DV];
gqa::gqa_scalar(&q, &k, &v, SEQ, DK, DV, HEADS, KV_HEADS, &mut output);
for i in 0..output.len() {
assert!(
output[i].is_finite(),
"KANI-GQ-001: output[{}] not finite",
i
);
}
}
#[kani::proof]
#[kani::unwind(5)]
#[kani::stub(f32::exp, stub_exp)]
#[kani::stub(f32::sqrt, stub_sqrt)]
fn verify_gqa_mha_equivalence() {
const SEQ: usize = 2;
const DK: usize = 2;
const DV: usize = 2;
const HEADS: usize = 2;
let q: [f32; HEADS * SEQ * DK] = kani::any();
let k: [f32; HEADS * SEQ * DK] = kani::any();
let v: [f32; HEADS * SEQ * DV] = kani::any();
kani::assume(q.iter().all(|x| x.is_finite()));
kani::assume(k.iter().all(|x| x.is_finite()));
kani::assume(v.iter().all(|x| x.is_finite()));
let mut output_gqa = [0.0f32; HEADS * SEQ * DV];
gqa::gqa_scalar(&q, &k, &v, SEQ, DK, DV, HEADS, HEADS, &mut output_gqa);
for i in 0..output_gqa.len() {
assert!(
output_gqa[i].is_finite(),
"KANI-GQ-002: output[{}] not finite",
i
);
}
}
#[kani::proof]
#[kani::unwind(5)]
#[kani::stub(f32::exp, stub_exp)]
#[kani::stub(f32::sqrt, stub_sqrt)]
fn verify_gqa_convex_bound() {
const SEQ: usize = 2;
const DK: usize = 2;
const DV: usize = 2;
const HEADS: usize = 2;
const KV_HEADS: usize = 2;
let q: [f32; HEADS * SEQ * DK] = kani::any();
let k: [f32; KV_HEADS * SEQ * DK] = kani::any();
let v: [f32; KV_HEADS * SEQ * DV] = kani::any();
kani::assume(q.iter().all(|x| x.is_finite()));
kani::assume(k.iter().all(|x| x.is_finite()));
kani::assume(v.iter().all(|x| x.is_finite()));
let mut output = [0.0f32; HEADS * SEQ * DV];
gqa::gqa_scalar(&q, &k, &v, SEQ, DK, DV, HEADS, KV_HEADS, &mut output);
for i in 0..output.len() {
assert!(
output[i].is_finite(),
"KANI-GQ-003: output[{}] not finite",
i
);
}
}
#[kani::proof]
#[kani::unwind(5)]
#[kani::stub(f32::exp, stub_exp)]
fn verify_online_softmax_2tiles() {
const N: usize = 4;
const D: usize = 2;
let q: [f32; N * D] = kani::any();
let k: [f32; N * D] = kani::any();
let v: [f32; N * D] = kani::any();
kani::assume(q.iter().all(|x| x.is_finite()));
kani::assume(k.iter().all(|x| x.is_finite()));
kani::assume(v.iter().all(|x| x.is_finite()));
let mut output = [0.0f32; N * D];
flash_attention::flash_attention_scalar(&q, &k, &v, N, D, 2, &mut output);
for i in 0..output.len() {
assert!(
output[i].is_finite(),
"KANI-FA-001: output[{}] not finite",
i
);
}
}