use super::*;
#[test]
fn falsify_aw_001_decoupled_weight_decay() {
    // AdamW applies weight decay directly to the parameters (decoupled),
    // while Adam folds it into the gradient; with wd > 0 a single step
    // must therefore move the two parameter vectors apart.
    let wd = 0.1;
    let lr = 0.01;

    // One AdamW step from a fixed starting point.
    clear_graph();
    let mut param_aw = Tensor::from_slice(&[5.0, -3.0, 2.0, -1.0]).requires_grad();
    let aw_loss = param_aw.pow(2.0).sum();
    aw_loss.backward();
    let mut adamw = AdamW::new(vec![&mut param_aw], lr).weight_decay(wd);
    adamw.step_with_params(&mut [&mut param_aw]);
    let aw_result: Vec<f32> = param_aw.data().to_vec();

    // One Adam step from the same starting point.
    clear_graph();
    let mut param_adam = Tensor::from_slice(&[5.0, -3.0, 2.0, -1.0]).requires_grad();
    let adam_loss = param_adam.pow(2.0).sum();
    adam_loss.backward();
    let mut adam = Adam::new(vec![&mut param_adam], lr).weight_decay(wd);
    adam.step_with_params(&mut [&mut param_adam]);
    let adam_result: Vec<f32> = param_adam.data().to_vec();

    // At least one coordinate must differ beyond float noise.
    let any_differ = aw_result
        .iter()
        .zip(&adam_result)
        .any(|(a, b)| (a - b).abs() > 1e-7);
    assert!(
        any_differ,
        "FALSIFIED AW-001: AdamW and Adam produced identical results with wd={wd}. \
         AdamW={aw_result:?}, Adam={adam_result:?}"
    );
}
#[test]
fn falsify_aw_002_second_moment_non_negative() {
    // The second-moment estimate v is an EMA of squared gradients and must
    // remain non-negative no matter how many optimizer steps are taken.
    clear_graph();
    let mut param = Tensor::from_slice(&[1.0, -2.0, 3.0, -4.0]).requires_grad();
    let mut adamw = AdamW::new(vec![&mut param], 0.001);

    for _step in 0..50 {
        clear_graph();
        // Detach so each iteration builds a fresh graph on the current values.
        param = param.detach().requires_grad();
        let loss = param.pow(2.0).sum();
        loss.backward();
        // Detaching gives the tensor a new id, so rebuild the optimizer
        // with the updated id while carrying over all accumulated state.
        adamw = AdamW {
            param_ids: vec![param.id()],
            m: adamw.m.clone(),
            v: adamw.v.clone(),
            t: adamw.t,
            lr: adamw.lr,
            beta1: adamw.beta1,
            beta2: adamw.beta2,
            eps: adamw.eps,
            weight_decay: adamw.weight_decay,
            initialized: adamw.initialized,
        };
        adamw.step_with_params(&mut [&mut param]);
    }

    // Every entry of every second-moment vector must be >= 0.
    for (row, v_vec) in adamw.v.iter().enumerate() {
        for (col, &entry) in v_vec.iter().enumerate() {
            assert!(
                entry >= 0.0,
                "FALSIFIED AW-002: v[{}][{}] = {} < 0 after 50 steps",
                row, col, entry
            );
        }
    }
}
#[test]
fn falsify_aw_003_bias_correction() {
    // The bias-correction factor 1/(1 - beta^t) must be a finite value
    // strictly greater than 1 for every beta in (0, 1) and every t >= 1,
    // since 0 < beta^t < 1 implies 0 < 1 - beta^t < 1.
    let betas = [0.9_f32, 0.99, 0.999];
    for beta in betas {
        for t in 1..=100 {
            let beta_power = beta.powi(t);
            let correction = 1.0 / (1.0 - beta_power);
            assert!(
                correction > 1.0,
                "FALSIFIED AW-003: 1/(1-{beta}^{t}) = {correction} not > 1"
            );
            assert!(
                correction.is_finite(),
                "FALSIFIED AW-003: 1/(1-{beta}^{t}) = {correction} not finite"
            );
        }
    }
}
#[test]
fn falsify_aw_004_update_finiteness() {
    // A single AdamW step on extreme (but finite) parameter magnitudes
    // must not produce NaN or infinity in the updated parameters.
    clear_graph();
    let mut param = Tensor::from_slice(&[1e6, -1e6, 1e-6, -1e-6]).requires_grad();
    let loss = param.pow(2.0).sum();
    loss.backward();
    let mut adamw = AdamW::new(vec![&mut param], 0.001);
    adamw.step_with_params(&mut [&mut param]);

    for (idx, &value) in param.data().iter().enumerate() {
        assert!(
            value.is_finite(),
            "FALSIFIED AW-004: param[{}] = {} (not finite after 1 step)",
            idx, value
        );
    }
}
#[test]
fn falsify_aw_006_zero_gradient_weight_decay_only() {
    // With a zero gradient, both moment estimates stay zero, so a single
    // AdamW update must reduce to pure decoupled weight decay:
    //   theta_new = theta * (1 - lr * wd).
    let lr = 0.01_f32;
    let wd = 0.1_f32;
    let beta1 = 0.9_f32;
    let beta2 = 0.999_f32;
    let eps = 1e-8_f32;
    let t = 1;
    let theta = 5.0_f32;
    let g = 0.0_f32;

    // Standard AdamW update equations evaluated with g = 0 (first step, so
    // the previous moments are 0.0).
    let m = beta1 * 0.0 + (1.0 - beta1) * g;
    let v = beta2 * 0.0 + (1.0 - beta2) * g * g;
    let m_hat = m / (1.0 - beta1.powi(t));
    let v_hat = v / (1.0 - beta2.powi(t));
    let theta_new = theta - lr * wd * theta - lr * m_hat / (v_hat.sqrt() + eps);
    let expected = theta * (1.0 - lr * wd);

    // `theta - lr*wd*theta` and `theta*(1 - lr*wd)` are algebraically equal
    // but round differently in f32; the previous 1e-10 bound was far below
    // one ulp at |theta| ~= 5 (~4.8e-7) and could fail spuriously on
    // ordinary rounding. Use a tolerance f32 arithmetic can actually meet.
    let diff = (theta_new - expected).abs();
    assert!(
        diff < 1e-6,
        "FALSIFIED AW-006: theta_new = {theta_new}, expected {expected} (only wd), diff = {diff}"
    );
}
// Property-based variants of the AdamW falsification tests above, driven by
// proptest. These exercise the update *equations* directly on plain f32
// arithmetic rather than going through the Tensor/AdamW machinery.
mod aw_proptest_falsify {
    use proptest::prelude::*;

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(100))]
        #[test]
        // AW-002 (property form): the second-moment EMA
        //   v <- beta2 * v + (1 - beta2) * g^2
        // stays non-negative for pseudo-random gradient sequences derived
        // deterministically from `seed`.
        fn falsify_aw_002_prop_second_moment_non_negative(
            seed in 0..1000u32,
        ) {
            let beta2 = 0.999_f32;
            let n = 4;
            let mut v = vec![0.0_f32; n];
            for step in 0..20 {
                // Deterministic pseudo-gradients in roughly [-10, 10],
                // varying per element, seed, and step.
                let g: Vec<f32> = (0..n)
                    .map(|i| ((i as f32 + seed as f32 + step as f32 * 13.0) * 0.37).sin() * 10.0)
                    .collect();
                for i in 0..n {
                    v[i] = beta2 * v[i] + (1.0 - beta2) * g[i] * g[i];
                }
            }
            for (i, &vi) in v.iter().enumerate() {
                prop_assert!(
                    vi >= 0.0,
                    "FALSIFIED AW-002-prop: v[{}] = {} < 0",
                    i, vi
                );
            }
        }
    }

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(200))]
        #[test]
        // AW-003 (property form): the bias correction 1/(1 - beta^t) is
        // finite and strictly > 1 whenever the denominator lies in (0, 1).
        fn falsify_aw_003_prop_bias_correction(
            beta in 0.5f32..0.9999,
            t in 1..=200i32,
        ) {
            let beta_power = beta.powi(t);
            let denom = 1.0 - beta_power;
            // Guard against f32 edge cases: denom == 0 when beta^t rounds
            // to 1, denom == 1 when beta^t underflows to 0; only the open
            // interval makes the strict ">" claim meaningful.
            if denom > 0.0 && denom < 1.0 {
                let correction = 1.0 / denom;
                prop_assert!(
                    correction > 1.0,
                    "FALSIFIED AW-003-prop: 1/(1-{}^{}) = {} not > 1",
                    beta, t, correction
                );
                prop_assert!(
                    correction.is_finite(),
                    "FALSIFIED AW-003-prop: correction not finite"
                );
            }
        }
    }

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(100))]
        #[test]
        // AW-006 (property form): with zero gradient the AdamW update
        // collapses to pure decoupled weight decay,
        //   theta_new = theta * (1 - lr * wd),
        // across randomized theta, lr, and wd.
        fn falsify_aw_006_prop_zero_gradient(
            theta in -100.0f32..100.0,
            lr in 0.0001f32..0.1,
            wd in 0.001f32..0.5,
        ) {
            let beta1 = 0.9_f32;
            let beta2 = 0.999_f32;
            let eps = 1e-8_f32;
            // First step with g = 0: both moments are exactly zero.
            let m = (1.0 - beta1) * 0.0;
            let v = (1.0 - beta2) * 0.0;
            let m_hat = m / (1.0 - beta1);
            let v_hat = v / (1.0 - beta2);
            let theta_new = theta - lr * wd * theta - lr * m_hat / (v_hat.sqrt() + eps);
            let expected = theta * (1.0 - lr * wd);
            prop_assert!(
                (theta_new - expected).abs() < 1e-4,
                "FALSIFIED AW-006-prop: theta_new={}, expected={}",
                theta_new, expected
            );
        }
    }
}