/// Proof harness: `activation::relu_scalar` never writes a negative value.
///
/// The input is any 32-element `f32` vector whose entries are all finite.
#[kani::proof]
#[kani::unwind(33)]
fn verify_relu_nonnegative() {
    const N: usize = 32;
    let input: [f32; N] = kani::any();
    kani::assume(input.iter().all(|v| v.is_finite()));

    let mut output = [0.0f32; N];
    activation::relu_scalar(&input, &mut output);

    for (i, &y) in output.iter().enumerate() {
        assert!(y >= 0.0, "KANI-ACT-001: output[{}] = {} < 0", i, y);
    }
}
/// Proof harness: `activation::relu_scalar` is element-wise monotone.
///
/// Two independent finite 32-element vectors are fed through ReLU. For every
/// index where the left input is `<=` the right input, the left output must
/// also be `<=` the right output. Indices where that precondition does not
/// hold are not constrained.
#[kani::proof]
#[kani::unwind(33)]
fn verify_relu_monotonic() {
    const N: usize = 32;
    let lhs: [f32; N] = kani::any();
    let rhs: [f32; N] = kani::any();
    kani::assume(lhs.iter().all(|v| v.is_finite()));
    kani::assume(rhs.iter().all(|v| v.is_finite()));

    let mut relu_lhs = [0.0f32; N];
    let mut relu_rhs = [0.0f32; N];
    activation::relu_scalar(&lhs, &mut relu_lhs);
    activation::relu_scalar(&rhs, &mut relu_rhs);

    for i in (0..N).filter(|&i| lhs[i] <= rhs[i]) {
        assert!(
            relu_lhs[i] <= relu_rhs[i],
            "KANI-ACT-002: monotonicity violated at {}",
            i
        );
    }
}
/// Proof harness: `activation::silu_scalar` maps `0.0` to (approximately) zero.
///
/// `f32::exp` is replaced by `stub_exp`, which is defined elsewhere. The
/// check allows an absolute error below `1e-5`.
#[kani::proof]
#[kani::stub(f32::exp, stub_exp)]
fn verify_silu_zero() {
    let zero_in = [0.0f32];
    let mut result = [0.0f32];
    activation::silu_scalar(&zero_in, &mut result);

    let y = result[0];
    assert!(y.abs() < 1e-5, "KANI-SI-001: silu(0) = {}, expected 0", y);
}
/// Proof harness: `activation::silu_scalar` yields finite outputs for finite inputs.
///
/// Covers every finite 8-element input, with `f32::exp` replaced by
/// `stub_exp` (defined elsewhere).
///
/// NOTE(review): despite the name, only finiteness is asserted. No numeric
/// lower bound on SiLU is checked here.
#[kani::proof]
#[kani::unwind(9)]
#[kani::stub(f32::exp, stub_exp)]
fn verify_silu_lower_bound() {
    const N: usize = 8;
    let input: [f32; N] = kani::any();
    kani::assume(input.iter().all(|v| v.is_finite()));

    let mut output = [0.0f32; N];
    activation::silu_scalar(&input, &mut output);

    for (i, y) in output.iter().enumerate() {
        assert!(y.is_finite(), "KANI-SI-002: output[{}] not finite", i);
    }
}
/// Proof harness: `softmax::softmax_scalar` outputs sum to 1 (within `1e-5`).
///
/// The input is kept to two elements and the CaDiCaL solver is selected to
/// keep the floating-point query tractable. `f32::exp` is replaced by
/// `stub_exp` (defined elsewhere).
#[kani::proof]
#[kani::unwind(3)]
#[kani::solver(cadical)]
#[kani::stub(f32::exp, stub_exp)]
fn verify_softmax_normalization() {
    const N: usize = 2;
    let logits: [f32; N] = kani::any();
    kani::assume(logits.iter().all(|v| v.is_finite()));

    let mut probs = [0.0f32; N];
    softmax::softmax_scalar(&logits, &mut probs);

    let total = probs.iter().sum::<f32>();
    assert!(
        (total - 1.0).abs() < 1e-5,
        "KANI-SM-001: sum = {}, expected 1.0",
        total
    );
}
/// Proof harness: every `softmax::softmax_scalar` output is strictly positive.
///
/// Covers all finite 8-element inputs, with `f32::exp` replaced by
/// `stub_exp` (defined elsewhere).
///
/// NOTE(review): whether this holds depends on how `stub_exp` models very
/// negative arguments. A real `exp` underflows to 0 there. Confirm against
/// the stub.
#[kani::proof]
#[kani::unwind(9)]
#[kani::stub(f32::exp, stub_exp)]
fn verify_softmax_positivity() {
    const N: usize = 8;
    let logits: [f32; N] = kani::any();
    kani::assume(logits.iter().all(|v| v.is_finite()));

    let mut probs = [0.0f32; N];
    softmax::softmax_scalar(&logits, &mut probs);

    for (i, &p) in probs.iter().enumerate() {
        assert!(p > 0.0, "KANI-SM-002: output[{}] = {} <= 0", i, p);
    }
}
/// Proof harness: every `softmax::softmax_scalar` output lies in `(0, 1]`.
///
/// Uses a two-element input and the CaDiCaL solver. `f32::exp` is replaced
/// by `stub_exp` (defined elsewhere).
#[kani::proof]
#[kani::unwind(3)]
#[kani::solver(cadical)]
#[kani::stub(f32::exp, stub_exp)]
fn verify_softmax_bounded() {
    const N: usize = 2;
    let logits: [f32; N] = kani::any();
    kani::assume(logits.iter().all(|v| v.is_finite()));

    let mut probs = [0.0f32; N];
    softmax::softmax_scalar(&logits, &mut probs);

    for (i, &p) in probs.iter().enumerate() {
        // Written as a conjunction (not a negated disjunction) so that a NaN
        // output fails the check.
        assert!(
            p > 0.0 && p <= 1.0,
            "KANI-SM-003: output[{}] = {} not in (0, 1]",
            i,
            p
        );
    }
}
/// Proof harness: `rmsnorm::rmsnorm_scalar` produces only finite outputs.
///
/// Preconditions:
/// - `input` and `gamma` are finite and bounded in magnitude by `MAX_ABS`.
/// - `eps` is finite and in `(0, 1)`.
///
/// Why the magnitude bound is needed: with only `is_finite()` the property is
/// false in IEEE-754 `f32`. Assuming the usual RMSNorm formula
/// `input[i] / sqrt(mean(input^2) + eps) * gamma[i]`, the normalized term can
/// reach `sqrt(N)` (= 4 here). For example, `input = [1, 0, ..., 0]` with
/// `gamma[0] = f32::MAX` overflows to `inf`. Unbounded inputs can likewise
/// overflow `input[i]^2` or `input[i] * gamma[i]`, giving `inf`/`NaN`
/// depending on evaluation order. Adding the bound only removes inputs, so it
/// cannot invalidate a previously passing run.
#[kani::proof]
#[kani::unwind(17)]
fn verify_rmsnorm_finiteness() {
    const N: usize = 16;
    // Far above realistic activation/weight magnitudes, yet far below the
    // overflow thresholds hit by squaring and scaling (sqrt(f32::MAX) ~ 1.8e19).
    const MAX_ABS: f32 = 1.0e6;
    let input: [f32; N] = kani::any();
    let gamma: [f32; N] = kani::any();
    kani::assume(input.iter().all(|x| x.is_finite() && x.abs() <= MAX_ABS));
    kani::assume(gamma.iter().all(|g| g.is_finite() && g.abs() <= MAX_ABS));
    let eps: f32 = kani::any();
    kani::assume(eps > 0.0 && eps.is_finite() && eps < 1.0);
    let mut output = [0.0f32; N];
    rmsnorm::rmsnorm_scalar(&input, &gamma, eps, &mut output);
    for i in 0..N {
        assert!(
            output[i].is_finite(),
            "KANI-RN-001: output[{}] is not finite",
            i
        );
    }
}
/// Proof harness: the RMSNorm denominator `mean(input^2) + eps` is positive.
///
/// The denominator is recomputed locally rather than through `rmsnorm`. This
/// harness uses any finite 16-element input and any finite `eps` in `(0, 1)`.
#[kani::proof]
#[kani::unwind(17)]
fn verify_rms_positive() {
    const N: usize = 16;
    let input: [f32; N] = kani::any();
    kani::assume(input.iter().all(|v| v.is_finite()));
    let eps: f32 = kani::any();
    kani::assume(eps > 0.0 && eps.is_finite() && eps < 1.0);

    // Left-to-right accumulation starting from +0.0.
    let squares = input.iter().fold(0.0f32, |acc, &x| acc + x * x);
    let mean_sq = squares / N as f32;
    let denom = mean_sq + eps;
    assert!(denom > 0.0, "KANI-RN-002: denominator = {} <= 0", denom);
}
/// Proof harness: the mean of `layernorm::layernorm_scalar` outputs is finite.
///
/// Setup: identity affine parameters (`gamma = 1`, `beta = 0`) and
/// `eps = 1e-5`. `f32::sqrt` is replaced by `stub_sqrt` (defined elsewhere).
///
/// NOTE(review): despite the name, this asserts only that the output mean is
/// finite. It does not assert that the mean is close to zero.
///
/// Why the magnitude bound is needed: with only `is_finite()`, an input such
/// as `[f32::MAX, -f32::MAX, ..., -f32::MAX]` makes the deviation
/// `input[0] - mean` overflow to `+inf` (whether the mean is computed as
/// `sum / N` or incrementally). That makes the outputs, and therefore their
/// mean, non-finite. Adding the bound only removes inputs.
#[kani::proof]
#[kani::unwind(9)]
#[kani::stub(f32::sqrt, stub_sqrt)]
fn verify_layernorm_centering() {
    const N: usize = 8;
    // Keeps sums, deviations and squared deviations well inside f32 range.
    const MAX_ABS: f32 = 1.0e6;
    let input: [f32; N] = kani::any();
    kani::assume(input.iter().all(|x| x.is_finite() && x.abs() <= MAX_ABS));
    let gamma = [1.0f32; N];
    let beta = [0.0f32; N];
    let mut output = [0.0f32; N];
    layernorm::layernorm_scalar(&input, &gamma, &beta, 1e-5, &mut output);
    let sum: f32 = output.iter().sum();
    let mean = sum / N as f32;
    assert!(mean.is_finite(), "KANI-LN-001: mean not finite");
}
/// Proof harness: every `layernorm::layernorm_scalar` output is finite.
///
/// Setup: identity affine parameters (`gamma = 1`, `beta = 0`) and
/// `eps = 1e-5`. `f32::sqrt` is replaced by `stub_sqrt` (defined elsewhere).
///
/// Why the magnitude bound is needed: with only `is_finite()`, an input such
/// as `[f32::MAX, -f32::MAX, ..., -f32::MAX]` makes `input[0] - mean`
/// overflow to `+inf`. Dividing (or multiplying by the reciprocal of) any
/// non-negative standard deviation then yields `inf` or `NaN`. Adding the
/// bound only removes inputs, so it cannot invalidate a previously passing run.
#[kani::proof]
#[kani::unwind(9)]
#[kani::stub(f32::sqrt, stub_sqrt)]
fn verify_layernorm_standardization() {
    const N: usize = 8;
    // Keeps sums, deviations and squared deviations well inside f32 range.
    const MAX_ABS: f32 = 1.0e6;
    let input: [f32; N] = kani::any();
    kani::assume(input.iter().all(|x| x.is_finite() && x.abs() <= MAX_ABS));
    let gamma = [1.0f32; N];
    let beta = [0.0f32; N];
    let mut output = [0.0f32; N];
    layernorm::layernorm_scalar(&input, &gamma, &beta, 1e-5, &mut output);
    for i in 0..N {
        assert!(
            output[i].is_finite(),
            "KANI-LN-002: output[{}] not finite",
            i
        );
    }
}
/// Proof harness: the LayerNorm denominator `variance + eps` is positive.
///
/// The mean and population variance are recomputed locally with two passes
/// over any finite 8-element input. `eps` is any finite value in `(0, 1)`.
/// No `sqrt` is called in this body, so the `stub_sqrt` attribute has no
/// effect here.
#[kani::proof]
#[kani::unwind(9)]
#[kani::stub(f32::sqrt, stub_sqrt)]
fn verify_layernorm_denominator_positive() {
    const N: usize = 8;
    let input: [f32; N] = kani::any();
    kani::assume(input.iter().all(|v| v.is_finite()));
    let eps: f32 = kani::any();
    kani::assume(eps > 0.0 && eps.is_finite() && eps < 1.0);

    // First pass: arithmetic mean.
    let total = input.iter().fold(0.0f32, |acc, &x| acc + x);
    let mean = total / N as f32;

    // Second pass: sum of squared deviations from the mean.
    let sq_dev = input.iter().fold(0.0f32, |acc, &x| {
        let d = x - mean;
        acc + d * d
    });
    let variance = sq_dev / N as f32;

    let denom = variance + eps;
    assert!(denom > 0.0, "KANI-LN-003: denom = {} <= 0", denom);
}
/// Proof harness: the per-channel BatchNorm denominator `batch_var + eps` is positive.
///
/// The input is a flat buffer indexed as `sample * CHANNELS + ch`. For each
/// channel, the batch mean and population variance are recomputed locally.
/// `eps` is any finite value in `(0, 1)`. No `sqrt` is called in this body,
/// so the `stub_sqrt` attribute has no effect here.
#[kani::proof]
#[kani::unwind(9)]
#[kani::stub(f32::sqrt, stub_sqrt)]
fn verify_batchnorm_denominator_positive() {
    const BATCH: usize = 4;
    const CHANNELS: usize = 2;
    let input: [f32; BATCH * CHANNELS] = kani::any();
    kani::assume(input.iter().all(|v| v.is_finite()));
    let eps: f32 = kani::any();
    kani::assume(eps > 0.0 && eps.is_finite() && eps < 1.0);

    for ch in 0..CHANNELS {
        // Value of channel `ch` for a given sample.
        let column = |sample: usize| input[sample * CHANNELS + ch];

        let batch_mean = (0..BATCH).fold(0.0f32, |acc, s| acc + column(s)) / BATCH as f32;
        let sq_dev = (0..BATCH).fold(0.0f32, |acc, s| {
            let d = column(s) - batch_mean;
            acc + d * d
        });
        let batch_var = sq_dev / BATCH as f32;

        let denom = batch_var + eps;
        assert!(denom > 0.0, "KANI-BN-001: ch {} denom = {} <= 0", ch, denom);
    }
}
/// Proof harness: `batchnorm::batchnorm_scalar` keeps the running variance non-negative.
///
/// Setup:
/// - One channel and a batch of four finite samples.
/// - `gamma = 1`, `beta = 0`, `eps = 1e-5`.
/// - An arbitrary finite, non-negative initial running variance.
/// - An arbitrary finite momentum in `(0, 1)`.
///
/// The final `true` argument presumably selects the mode that updates the
/// running statistics. TODO confirm against the `batchnorm_scalar` signature.
#[kani::proof]
#[kani::unwind(5)]
#[kani::stub(f32::sqrt, stub_sqrt)]
fn verify_running_variance_nonneg() {
    const BATCH: usize = 4;
    const CHANNELS: usize = 1;
    let samples: [f32; BATCH * CHANNELS] = kani::any();
    kani::assume(samples.iter().all(|v| v.is_finite()));

    let gamma = [1.0f32; CHANNELS];
    let beta = [0.0f32; CHANNELS];
    let mut run_mean = [0.0f32; CHANNELS];

    let initial_var: f32 = kani::any();
    kani::assume(initial_var >= 0.0 && initial_var.is_finite());
    let mut run_var = [initial_var; CHANNELS];

    let momentum: f32 = kani::any();
    kani::assume(momentum > 0.0 && momentum < 1.0 && momentum.is_finite());

    let mut normalized = [0.0f32; BATCH * CHANNELS];
    batchnorm::batchnorm_scalar(
        &samples,
        BATCH,
        CHANNELS,
        &gamma,
        &beta,
        1e-5,
        &mut run_mean,
        &mut run_var,
        &mut normalized,
        momentum,
        true,
    );

    let var_after = run_var[0];
    assert!(var_after >= 0.0, "KANI-BN-002: running_var = {} < 0", var_after);
}