/// Step-decay learning-rate schedule: halves `initial_lr` once every
/// `step_decay_every` steps, i.e. `initial_lr / 2^(step / step_decay_every)`.
///
/// A `step_decay_every` of 0 is treated as 1 (halve on every step) so the
/// division below can never divide by zero.
///
/// The number of halvings is applied as a floating-point power of two rather
/// than an integer shift: a `u32` shift overflows once the halving count
/// reaches 32 (panicking in debug builds, and silently resetting the factor in
/// release builds). Powers of two are exact in `f32`, so results for small
/// halving counts are unchanged, while very large counts decay toward `0.0`.
fn step_decay_lr(initial_lr: f32, step: u32, step_decay_every: u32) -> f32 {
    let halvings = step / step_decay_every.max(1);
    // Saturate instead of wrapping: 2^i32::MAX is +inf in f32, which drives
    // the learning rate to 0.0 rather than producing a negative exponent.
    let exponent = i32::try_from(halvings).unwrap_or(i32::MAX);
    initial_lr / 2.0_f32.powi(exponent)
}
/// Exponential-decay learning-rate schedule: `initial_lr * decay_rate^step`.
///
/// `step` is saturated to `i32::MAX` before exponentiation. A plain
/// `step as i32` cast wraps steps above `i32::MAX` into negative exponents,
/// which would turn decay into growth (e.g. `u32::MAX` would become `-1`).
fn exp_decay_lr(initial_lr: f32, step: u32, decay_rate: f32) -> f32 {
    let exponent = i32::try_from(step).unwrap_or(i32::MAX);
    initial_lr * decay_rate.powi(exponent)
}
/// Chapter 6 (optimization) demo.
///
/// Prints a step-decay and an exponential-decay learning-rate schedule, then
/// runs plain gradient descent and heavy-ball momentum side by side on
/// `f(x) = x^2` (gradient `2x`), printing each iterate.
///
/// # Panics
///
/// Panics if any of the schedule or convergence sanity assertions fail.
fn main() -> tokitai_operator::Result<()> {
    println!("== step-decay schedule (initial=0.1, halve every 2 steps) ==");
    for step in 0..5 {
        let lr = step_decay_lr(0.1, step, 2);
        println!(" step {}: lr = {:.4}", step, lr);
    }
    assert!((step_decay_lr(0.1, 0, 2) - 0.1).abs() < 1e-6);
    assert!((step_decay_lr(0.1, 2, 2) - 0.05).abs() < 1e-6);
    assert!((step_decay_lr(0.1, 4, 2) - 0.025).abs() < 1e-6);

    println!("\n== exp-decay schedule (initial=0.1, rate=0.9) ==");
    for step in 0..5 {
        let lr = exp_decay_lr(0.1, step, 0.9);
        println!(" step {}: lr = {:.4}", step, lr);
    }
    assert!(exp_decay_lr(0.1, 0, 0.9) > exp_decay_lr(0.1, 4, 0.9));

    // Single source of truth for the iteration count: the header previously
    // advertised 5 steps while the loop actually ran 10.
    const OPT_STEPS: u32 = 10;
    println!(
        "\n== momentum-vs-SGD on f(x) = x^2 (start x=4, {} steps) ==",
        OPT_STEPS
    );
    let mut x_sgd = 4.0_f32;
    let mut x_mom = 4.0_f32;
    let mut v_mom = 0.0_f32;
    let lr = 0.2_f32;
    let momentum = 0.5_f32;
    for step in 0..OPT_STEPS {
        // Plain SGD: x <- x - lr * f'(x), with f'(x) = 2x.
        let grad = 2.0 * x_sgd;
        x_sgd -= lr * grad;
        // Heavy-ball momentum: v <- momentum * v + f'(x); x <- x - lr * v.
        let grad_m = 2.0 * x_mom;
        v_mom = momentum * v_mom + grad_m;
        x_mom -= lr * v_mom;
        println!(
            " step {}: sgd x = {:.4} (grad={:.4}), momentum x = {:.4} (v={:.4})",
            step, x_sgd, grad, x_mom, v_mom
        );
    }
    assert!(x_sgd.abs() < 1.0, "SGD should converge near 0: x={}", x_sgd);
    assert!(
        x_mom.abs() < 1.0,
        "Momentum should converge near 0: x={}",
        x_mom
    );
    println!("\nchapter 6 (optimization) ok");
    Ok(())
}