/// KANI-AW-001: after one `adamw::adamw_step_scalar` call, starting from
/// zero-filled `m`/`v` buffers, every parameter must still be finite.
///
/// Parameters and gradients are arbitrary finite `f32` values; `lr` and `eps`
/// are symbolic values restricted to the open interval (0, 1). `f32::sqrt` and
/// `f32::powi` are replaced by `stub_sqrt` / `stub_powi` for this proof.
///
/// NOTE(review): despite the harness name, only finiteness of `params` is
/// asserted; no property specific to decoupled weight decay is checked here.
#[kani::proof]
#[kani::unwind(5)]
#[kani::stub(f32::sqrt, stub_sqrt)]
#[kani::stub(f32::powi, stub_powi)]
fn verify_adamw_decoupled() {
    const N: usize = 4;

    // Symbolic, finite parameters and gradients.
    let mut params: [f32; N] = kani::any();
    let grads: [f32; N] = kani::any();
    kani::assume(params.iter().all(|p| p.is_finite()));
    kani::assume(grads.iter().all(|g| g.is_finite()));

    // Moment buffers start at zero.
    let mut first_moment = [0.0f32; N];
    let mut second_moment = [0.0f32; N];

    let lr: f32 = kani::any();
    kani::assume(lr.is_finite() && lr > 0.0 && lr < 1.0);
    let eps: f32 = kani::any();
    kani::assume(eps.is_finite() && eps > 0.0 && eps < 1.0);

    // Presumably beta1, beta2, weight decay and step index — confirm against
    // the positional signature of `adamw_step_scalar`.
    let (beta1, beta2, weight_decay, step) = (0.9, 0.999, 0.01, 1);

    adamw::adamw_step_scalar(
        &mut params,
        &grads,
        &mut first_moment,
        &mut second_moment,
        lr,
        beta1,
        beta2,
        eps,
        weight_decay,
        step,
    );

    for (idx, p) in params.iter().enumerate() {
        assert!(p.is_finite(), "KANI-AW-001: params[{}] not finite", idx);
    }
}
/// KANI-AW-002: after one `adamw::adamw_step_scalar` call with fixed
/// hyper-parameters, every entry of the `v` buffer is `>= 0.0`.
///
/// Parameters and gradients are arbitrary finite `f32` values and both moment
/// buffers start at zero. `f32::sqrt` / `f32::powi` are replaced by
/// `stub_sqrt` / `stub_powi`.
///
/// NOTE(review): `v` is presumably the second-moment (squared-gradient)
/// accumulator — confirm against `adamw_step_scalar`.
#[kani::proof]
#[kani::unwind(9)]
#[kani::stub(f32::sqrt, stub_sqrt)]
#[kani::stub(f32::powi, stub_powi)]
fn verify_adamw_moment_positive() {
    const N: usize = 8;

    let mut params: [f32; N] = kani::any();
    let grads: [f32; N] = kani::any();
    kani::assume(params.iter().all(|p| p.is_finite()));
    kani::assume(grads.iter().all(|g| g.is_finite()));

    let mut first_moment = [0.0f32; N];
    let mut second_moment = [0.0f32; N];

    // Presumably lr, beta1, beta2, eps, weight decay and step index — confirm
    // against the positional signature of `adamw_step_scalar`.
    let (lr, beta1, beta2, eps, weight_decay, step) = (0.001, 0.9, 0.999, 1e-8, 0.01, 1);

    adamw::adamw_step_scalar(
        &mut params,
        &grads,
        &mut first_moment,
        &mut second_moment,
        lr,
        beta1,
        beta2,
        eps,
        weight_decay,
        step,
    );

    for (idx, &vi) in second_moment.iter().enumerate() {
        assert!(vi >= 0.0, "KANI-AW-002: v[{}] = {} < 0", idx, vi);
    }
}
/// KANI-AW-003: with parameters and gradients bounded in magnitude by 100 and
/// a symbolic `eps` in (1e-10, 1), one `adamw::adamw_step_scalar` call leaves
/// every parameter finite.
///
/// Both moment buffers start at zero. `f32::sqrt` / `f32::powi` are replaced by
/// `stub_sqrt` / `stub_powi`.
#[kani::proof]
#[kani::unwind(5)]
#[kani::stub(f32::sqrt, stub_sqrt)]
#[kani::stub(f32::powi, stub_powi)]
fn verify_adamw_finite_update() {
    const N: usize = 4;
    // Exclusive magnitude bound on every parameter and gradient.
    const MAG: f32 = 100.0;

    let mut params: [f32; N] = kani::any();
    let grads: [f32; N] = kani::any();
    kani::assume(params.iter().all(|p| p.is_finite() && p.abs() < MAG));
    kani::assume(grads.iter().all(|g| g.is_finite() && g.abs() < MAG));

    let mut first_moment = [0.0f32; N];
    let mut second_moment = [0.0f32; N];

    let eps: f32 = kani::any();
    kani::assume(eps.is_finite() && eps > 1e-10 && eps < 1.0);

    // Presumably lr, beta1, beta2, weight decay and step index — confirm
    // against the positional signature of `adamw_step_scalar`.
    let (lr, beta1, beta2, weight_decay, step) = (0.001, 0.9, 0.999, 0.01, 1);

    adamw::adamw_step_scalar(
        &mut params,
        &grads,
        &mut first_moment,
        &mut second_moment,
        lr,
        beta1,
        beta2,
        eps,
        weight_decay,
        step,
    );

    for (idx, p) in params.iter().enumerate() {
        assert!(p.is_finite(), "KANI-AW-003: params[{}] not finite", idx);
    }
}
/// KANI-CV-001: a single-channel `conv1d::conv1d_scalar` call (no bias,
/// stride 1, no padding) fills an output buffer of exactly `C_OUT * OUT_LEN`
/// elements with finite values.
///
/// Inputs and weights are symbolic but bounded in magnitude by `MAG`. Without
/// such a bound the finiteness property is simply false for any
/// multiply-accumulate kernel: two finite operands such as `1e20 * 1e20`
/// already overflow `f32` to infinity. With `|x|, |w| < MAG`, each output is
/// a sum of `C_IN * KERNEL_SIZE = 3` products, each of magnitude < `MAG²`.
/// That sum is < 3e12, far below `f32::MAX`.
///
/// `output` is a fixed-size array sized from the expected output length. If
/// the kernel indexes the output slice out of bounds, that index panics, and
/// Kani reports the panic as a failure.
#[kani::proof]
#[kani::unwind(9)]
fn verify_conv1d_output_shape() {
    const C_IN: usize = 1;
    const C_OUT: usize = 1;
    const LENGTH: usize = 8;
    const KERNEL_SIZE: usize = 3;
    const STRIDE: usize = 1;
    const PADDING: usize = 0;
    const OUT_LEN: usize = (LENGTH + 2 * PADDING - KERNEL_SIZE) / STRIDE + 1;
    // Exclusive magnitude bound on inputs and weights; keeps every
    // accumulated sum well inside the finite f32 range (see doc comment).
    const MAG: f32 = 1.0e6;

    let input: [f32; C_IN * LENGTH] = kani::any();
    let weight: [f32; C_OUT * C_IN * KERNEL_SIZE] = kani::any();
    kani::assume(input.iter().all(|x| x.is_finite() && x.abs() < MAG));
    kani::assume(weight.iter().all(|w| w.is_finite() && w.abs() < MAG));

    let mut output = [0.0f32; C_OUT * OUT_LEN];
    conv1d::conv1d_scalar(
        &input,
        &weight,
        None,
        C_IN,
        C_OUT,
        LENGTH,
        KERNEL_SIZE,
        STRIDE,
        PADDING,
        &mut output,
    );

    for (i, y) in output.iter().enumerate() {
        assert!(y.is_finite(), "KANI-CV-001: output[{}] not finite", i);
    }
}
/// KANI-CV-002: `conv1d::conv1d_scalar` is homogeneous in its input, i.e.
/// `conv(alpha * x) ≈ alpha * conv(x)` for every output element (no bias).
///
/// Inputs and weights are finite with magnitude < 10; `alpha` is finite with
/// 0.01 < |alpha| < 10.
///
/// A point passes if it is within the original tolerance (relative `1e-4` to
/// `max(|expected|, 1)`, or absolute `1e-5`). It also passes if it is within
/// an absolute rounding bound proportional to the magnitude of the
/// *accumulated terms*. The second criterion matters under near-cancellation:
/// the result can be ~0 while the individual products are ~1000. In that case
/// legitimate f32 rounding easily exceeds `1e-4`. Both sides then differ only
/// by roughly `(TERMS + 1) * EPSILON * sum|w·αx|`. The bound used below is
/// 4x that, with an extra +1 on the term count.
///
/// Every element of the `C_OUT * OUT_LEN` output is checked, not just the
/// first channel.
#[kani::proof]
#[kani::unwind(5)]
fn verify_conv1d_linearity() {
    const C_IN: usize = 1;
    const C_OUT: usize = 1;
    const LENGTH: usize = 4;
    const KERNEL_SIZE: usize = 2;
    const STRIDE: usize = 1;
    const PADDING: usize = 0;
    const OUT_LEN: usize = (LENGTH + 2 * PADDING - KERNEL_SIZE) / STRIDE + 1;
    // Number of products accumulated into each output element.
    const TERMS: usize = C_IN * KERNEL_SIZE;

    let input: [f32; C_IN * LENGTH] = kani::any();
    let weight: [f32; C_OUT * C_IN * KERNEL_SIZE] = kani::any();
    kani::assume(input.iter().all(|x| x.is_finite() && x.abs() < 10.0));
    kani::assume(weight.iter().all(|x| x.is_finite() && x.abs() < 10.0));
    let alpha: f32 = kani::any();
    kani::assume(alpha.is_finite() && alpha.abs() < 10.0 && alpha.abs() > 0.01);

    let mut out1 = [0.0f32; C_OUT * OUT_LEN];
    conv1d::conv1d_scalar(
        &input,
        &weight,
        None,
        C_IN,
        C_OUT,
        LENGTH,
        KERNEL_SIZE,
        STRIDE,
        PADDING,
        &mut out1,
    );

    let mut scaled_input = [0.0f32; C_IN * LENGTH];
    for (dst, &src) in scaled_input.iter_mut().zip(input.iter()) {
        *dst = alpha * src;
    }

    let mut out2 = [0.0f32; C_OUT * OUT_LEN];
    conv1d::conv1d_scalar(
        &scaled_input,
        &weight,
        None,
        C_IN,
        C_OUT,
        LENGTH,
        KERNEL_SIZE,
        STRIDE,
        PADDING,
        &mut out2,
    );

    // Upper bound on sum_k |w_k| * |alpha * x_k| for any single output element:
    // each of the TERMS products is at most max|w| * |alpha| * max|x|.
    let max_in = input.iter().fold(0.0f32, |acc, x| acc.max(x.abs()));
    let max_w = weight.iter().fold(0.0f32, |acc, w| acc.max(w.abs()));
    let term_bound = TERMS as f32 * max_w * alpha.abs() * max_in;
    // Rounding allowance scaled by the magnitude of the accumulated terms.
    let abs_tol = 4.0 * (TERMS + 2) as f32 * f32::EPSILON * term_bound;

    for (i, (&base, &actual)) in out1.iter().zip(out2.iter()).enumerate() {
        let expected = alpha * base;
        let diff = (expected - actual).abs();
        let scale = expected.abs().max(1.0);
        assert!(
            diff / scale < 1e-4 || diff < 1e-5 || diff <= abs_tol,
            "KANI-CV-002: linearity violated at {}: {} vs {}",
            i,
            actual,
            expected
        );
    }
}
/// KANI-SSM-001: `ssm::ssm_scan_scalar` is causal. Replacing input element
/// `x[CHANGED]` with an arbitrary finite value must leave every earlier
/// output `out[0..CHANGED]` unchanged.
///
/// All parameters and inputs are arbitrary finite `f32` values. Because they
/// are unbounded, intermediate products can overflow, and an output can
/// legitimately be NaN (e.g. `0 * inf`). Two outputs therefore count as
/// "unchanged" if they compare equal or are both NaN; plain `==` would report
/// a spurious violation for NaN == NaN.
///
/// The unwind bound must cover the longest loop in the harness. That loop is
/// the finiteness assumption over `b_bar`, which has
/// `STATE_DIM * SEQ_LEN = 8` elements and so needs an unwind of 9.
#[kani::proof]
#[kani::unwind(9)]
fn verify_ssm_causality() {
    const STATE_DIM: usize = 2;
    const SEQ_LEN: usize = 4;
    // Index of the input element that differs between the two runs.
    const CHANGED: usize = 2;

    let a_bar: [f32; STATE_DIM] = kani::any();
    let b_bar: [f32; STATE_DIM * SEQ_LEN] = kani::any();
    let c: [f32; STATE_DIM] = kani::any();
    let x1: [f32; SEQ_LEN] = kani::any();
    kani::assume(a_bar.iter().all(|v| v.is_finite()));
    kani::assume(b_bar.iter().all(|v| v.is_finite()));
    kani::assume(c.iter().all(|v| v.is_finite()));
    kani::assume(x1.iter().all(|v| v.is_finite()));

    let mut out1 = [0.0f32; SEQ_LEN];
    ssm::ssm_scan_scalar(&a_bar, &b_bar, &c, &x1, STATE_DIM, SEQ_LEN, &mut out1);

    let mut x2 = x1;
    let new_val: f32 = kani::any();
    kani::assume(new_val.is_finite());
    x2[CHANGED] = new_val;

    let mut out2 = [0.0f32; SEQ_LEN];
    ssm::ssm_scan_scalar(&a_bar, &b_bar, &c, &x2, STATE_DIM, SEQ_LEN, &mut out2);

    for t in 0..CHANGED {
        let (before, after) = (out1[t], out2[t]);
        let unchanged = before == after || (before.is_nan() && after.is_nan());
        assert!(
            unchanged,
            "KANI-SSM-001: output[{}] changed: {} vs {}",
            t,
            before,
            after
        );
    }
}
#[kani::proof]
#[kani::unwind(9)]
#[kani::stub(f32::exp, stub_exp)]
#[kani::stub(f32::ln, stub_ln)]
fn verify_softplus_positive() {
const N: usize = 8;
let x: [f32; N] = kani::any();
kani::assume(x.iter().all(|v| v.is_finite()));
for i in 0..N {
let exp_x = stub_exp(x[i]);
let arg = 1.0 + exp_x;
assert!(arg > 0.0, "KANI-SSM-002: 1 + exp(x[{}]) = {} <= 0", i, arg);
}
}