use crate::error::Result;
use crate::ops::FusedOptimizerOps;
use crate::optimizer::traits::Optimizer;
use numr::autograd::GradStore;
use numr::dtype::DType;
use numr::ops::{BinaryOps, ReduceOps, ScalarOps, UnaryOps};
use numr::runtime::{Runtime, RuntimeClient};
use numr::tensor::{Tensor, TensorId};
use std::collections::HashMap;
/// Hyperparameters for the [`Lamb`] optimizer (covers both LAMB and LARS modes).
#[derive(Debug, Clone)]
pub struct LambConfig {
    /// Base learning rate; scaled per-parameter by the layer-wise trust ratio.
    pub lr: f64,
    /// Exponential decay rate for the first-moment estimate.
    pub beta1: f64,
    /// Exponential decay rate for the second-moment estimate
    /// (set to 0.0 by [`LambConfig::lars`], which also disables its bias correction).
    pub beta2: f64,
    /// Small constant for numerical stability.
    pub eps: f64,
    /// Weight-decay coefficient, forwarded to the fused step kernel.
    pub weight_decay: f64,
    /// Upper clamp applied to the trust ratio; `None` leaves it unclamped.
    pub max_trust_ratio: Option<f64>,
    /// `true` → Adam-style update (LAMB); `false` → LARS-style update
    /// (second-moment bias correction is fixed at 1.0).
    pub use_adam: bool,
}
impl Default for LambConfig {
fn default() -> Self {
Self {
lr: 1e-3,
beta1: 0.9,
beta2: 0.999,
eps: 1e-6,
weight_decay: 0.01,
max_trust_ratio: Some(10.0),
use_adam: true,
}
}
}
impl LambConfig {
    /// Preset for classic LARS: momentum-style updates with no second moment
    /// (`beta2 = 0.0`, `use_adam = false`), a larger base learning rate, and
    /// lighter weight decay. Every other field keeps its LAMB default.
    pub fn lars() -> Self {
        Self {
            lr: 0.1,
            beta2: 0.0,
            weight_decay: 1e-4,
            use_adam: false,
            ..Self::default()
        }
    }
}
/// Per-parameter optimizer state: the first (`m`) and second (`v`) moment
/// buffers threaded through `fused_lamb_step` and replaced on every update.
struct LambState<R: Runtime> {
    // First-moment buffer.
    m: Tensor<R>,
    // Second-moment buffer (still allocated in LARS mode — presumably ignored
    // by the fused kernel when beta2 == 0.0; confirm against the kernel).
    v: Tensor<R>,
}
/// LAMB (Layer-wise Adaptive Moments) optimizer; also serves as LARS when
/// constructed from [`LambConfig::lars`].
pub struct Lamb<R: Runtime> {
    // Hyperparameters; only `lr` is mutable after construction (via `set_lr`).
    config: LambConfig,
    // Lazily-created moment buffers, keyed by parameter tensor id.
    state: HashMap<TensorId, LambState<R>>,
    // Number of `step` calls since construction or the last `reset`;
    // drives the bias-correction terms.
    timestep: u64,
}
impl<R: Runtime<DType = DType>> Lamb<R> {
    /// Create an optimizer with the given hyperparameters. Moment buffers are
    /// allocated lazily on the first [`Optimizer::step`] call.
    pub fn new(config: LambConfig) -> Self {
        Self {
            timestep: 0,
            state: HashMap::new(),
            config,
        }
    }

    /// Borrow the active hyperparameter configuration.
    pub fn config(&self) -> &LambConfig {
        &self.config
    }

    /// Number of optimization steps taken so far.
    pub fn timestep(&self) -> u64 {
        self.timestep
    }
}
/// Compute the L2 norm ‖t‖₂ = sqrt(Σ tᵢ²) of a tensor as a host-side scalar.
///
/// The elementwise square and the full reduction run through the client; only
/// the final scalar is copied back to the host.
fn tensor_l2_norm<R, C>(client: &C, t: &Tensor<R>) -> Result<f64>
where
    R: Runtime<DType = DType>,
    C: RuntimeClient<R> + BinaryOps<R> + UnaryOps<R> + ReduceOps<R>,
{
    let squared = client.mul(t, t)?;
    // NOTE(review): item::<f32>() assumes the reduction produces an f32
    // scalar — confirm behavior for f64-dtype tensors.
    let total: f32 = client.sum(&squared, &[], false)?.item()?;
    Ok(f64::from(total).sqrt())
}
impl<R: Runtime<DType = DType>> Optimizer<R> for Lamb<R> {
    /// Perform one LAMB/LARS step over every parameter that has a gradient.
    ///
    /// For each parameter the fused kernel returns the unscaled update plus
    /// fresh moment buffers; the update is then rescaled by the layer-wise
    /// trust ratio ‖w‖ / ‖u‖ (clamped by `max_trust_ratio` when set) and
    /// subtracted from the parameter. Parameters with no entry in `grads`
    /// are skipped unchanged.
    ///
    /// # Errors
    /// Propagates any error from the client tensor operations.
    fn step<C>(
        &mut self,
        client: &C,
        params: &mut HashMap<TensorId, Tensor<R>>,
        grads: &GradStore<R>,
    ) -> Result<()>
    where
        C: RuntimeClient<R>
            + BinaryOps<R>
            + UnaryOps<R>
            + ScalarOps<R>
            + ReduceOps<R>
            + FusedOptimizerOps<R>,
    {
        self.timestep += 1;
        let t = self.timestep;
        let beta1 = self.config.beta1;
        let beta2 = self.config.beta2;
        let lr = self.config.lr;
        let eps = self.config.eps;
        let wd = self.config.weight_decay;
        // Hoisted out of the loop: Option<f64> is Copy, and reading it here
        // keeps the loop body free of `self.config` borrows.
        let max_trust_ratio = self.config.max_trust_ratio;
        // Bias-correction terms. In LARS mode (use_adam == false) the
        // second-moment correction is fixed at 1.0, matching beta2 == 0.0.
        let bc1 = 1.0 - beta1.powi(t as i32);
        let bc2 = if self.config.use_adam {
            1.0 - beta2.powi(t as i32)
        } else {
            1.0
        };
        // Snapshot ids so we can insert updated tensors back into `params`
        // while iterating.
        let param_ids: Vec<TensorId> = params.keys().copied().collect();
        for id in param_ids {
            let grad = match grads.get(id) {
                Some(g) => g,
                None => continue,
            };
            let param = params.get(&id).expect("id collected from params.keys()");
            // Single hash lookup: the entry API both lazily zero-initializes
            // the moment buffers and hands back the mutable reference used
            // for the write-back below (previously three separate lookups).
            let state = self.state.entry(id).or_insert_with(|| LambState {
                m: Tensor::<R>::zeros(param.shape(), param.dtype(), param.device()),
                v: Tensor::<R>::zeros(param.shape(), param.dtype(), param.device()),
            });
            let (update, new_m, new_v) = client.fused_lamb_step(
                param, grad, &state.m, &state.v, beta1, beta2, eps, wd, bc1, bc2,
            )?;
            state.m = new_m;
            state.v = new_v;
            // Layer-wise trust ratio: scale the step so its magnitude tracks
            // the parameter's magnitude.
            let param_norm = tensor_l2_norm(client, param)?;
            let update_norm = tensor_l2_norm(client, &update)?;
            let trust_ratio = if param_norm > 0.0 && update_norm > 0.0 {
                let ratio = param_norm / update_norm;
                match max_trust_ratio {
                    Some(max) => ratio.min(max),
                    None => ratio,
                }
            } else {
                // Degenerate norms (e.g. zero-initialized weights) fall back
                // to a neutral ratio so the parameter still moves.
                1.0
            };
            let effective_lr = lr * trust_ratio;
            let scaled_update = client.mul_scalar(&update, effective_lr)?;
            let new_param = client.sub(param, &scaled_update)?;
            params.insert(id, new_param);
        }
        Ok(())
    }

    /// Override the base learning rate (e.g. from an LR scheduler).
    fn set_lr(&mut self, lr: f64) {
        self.config.lr = lr;
    }

    /// Current base learning rate.
    fn lr(&self) -> f64 {
        self.config.lr
    }

    /// Drop all moment buffers and rewind the timestep to zero, as if
    /// freshly constructed.
    fn reset(&mut self) {
        self.state.clear();
        self.timestep = 0;
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_utils::cpu_setup;
    use numr::autograd::{Var, backward, var_mean, var_mul, var_sub};
    use numr::runtime::cpu::CpuRuntime;

    // Defaults must match the documented LAMB hyperparameters.
    #[test]
    fn test_lamb_default_config() {
        let config = LambConfig::default();
        assert_eq!(config.lr, 1e-3);
        assert!(config.use_adam);
        assert_eq!(config.max_trust_ratio, Some(10.0));
    }

    // The LARS preset must disable the Adam-style second moment.
    #[test]
    fn test_lars_config() {
        let config = LambConfig::lars();
        assert_eq!(config.lr, 0.1);
        assert!(!config.use_adam);
    }

    // Fit a 2x2 weight matrix to a fixed target under an MSE loss:
    // 50 LAMB steps should cut the loss by at least 10x.
    #[test]
    fn test_lamb_converges() {
        let (client, device) = cpu_setup();
        let target = Tensor::<CpuRuntime>::from_slice(&[1.0f32, 0.0, 0.0, 1.0], &[2, 2], &device);
        let w_init = Tensor::<CpuRuntime>::from_slice(&[0.0f32, 0.0, 0.0, 0.0], &[2, 2], &device);
        let w_id = w_init.id();
        let mut params = HashMap::new();
        params.insert(w_id, w_init);
        let mut opt = Lamb::<CpuRuntime>::new(LambConfig {
            lr: 0.1,
            weight_decay: 0.0,
            ..Default::default()
        });
        let mut first_loss = 0.0f64;
        let mut last_loss = 0.0f64;
        for i in 0..50 {
            // Rebuild the autograd graph each iteration from the current
            // parameter value, reusing the stable tensor id so the optimizer
            // state and gradient store line up.
            let w_tensor = params.get(&w_id).unwrap().clone();
            let w = Var::with_id(w_tensor, w_id, true);
            let t = Var::new(target.clone(), false);
            let diff = var_sub(&w, &t, &client).unwrap();
            let sq = var_mul(&diff, &diff, &client).unwrap();
            let loss = var_mean(&sq, &[0, 1], false, &client).unwrap();
            let loss_val = loss.tensor().to_vec::<f32>()[0] as f64;
            if i == 0 {
                first_loss = loss_val;
            }
            last_loss = loss_val;
            let grads = backward(&loss, &client).unwrap();
            opt.step(&client, &mut params, &grads).unwrap();
        }
        assert!(
            last_loss < first_loss * 0.1,
            "LAMB should converge: first={first_loss} last={last_loss}"
        );
    }

    // Same regression problem as above, but driven by the LARS preset
    // (no second moment) — it should also converge.
    #[test]
    fn test_lars_converges() {
        let (client, device) = cpu_setup();
        let target = Tensor::<CpuRuntime>::from_slice(&[1.0f32, 0.0, 0.0, 1.0], &[2, 2], &device);
        let w_init = Tensor::<CpuRuntime>::from_slice(&[0.0f32, 0.0, 0.0, 0.0], &[2, 2], &device);
        let w_id = w_init.id();
        let mut params = HashMap::new();
        params.insert(w_id, w_init);
        let mut opt = Lamb::<CpuRuntime>::new(LambConfig {
            weight_decay: 0.0,
            ..LambConfig::lars()
        });
        let mut first_loss = 0.0f64;
        let mut last_loss = 0.0f64;
        for i in 0..50 {
            let w_tensor = params.get(&w_id).unwrap().clone();
            let w = Var::with_id(w_tensor, w_id, true);
            let t = Var::new(target.clone(), false);
            let diff = var_sub(&w, &t, &client).unwrap();
            let sq = var_mul(&diff, &diff, &client).unwrap();
            let loss = var_mean(&sq, &[0, 1], false, &client).unwrap();
            let loss_val = loss.tensor().to_vec::<f32>()[0] as f64;
            if i == 0 {
                first_loss = loss_val;
            }
            last_loss = loss_val;
            let grads = backward(&loss, &client).unwrap();
            opt.step(&client, &mut params, &grads).unwrap();
        }
        assert!(
            last_loss < first_loss * 0.1,
            "LARS should converge: first={first_loss} last={last_loss}"
        );
    }

    // Large weights with tiny gradients would produce a huge raw trust
    // ratio; the Some(10.0) clamp must keep the applied step small and finite.
    #[test]
    fn test_lamb_trust_ratio_clamped() {
        let (client, device) = cpu_setup();
        let w_tensor = Tensor::<CpuRuntime>::from_slice(&[100.0f32, 100.0], &[2], &device);
        let w_id = w_tensor.id();
        let grad = Tensor::<CpuRuntime>::from_slice(&[0.001f32, 0.001], &[2], &device);
        let mut grads = GradStore::new();
        grads.insert(w_id, grad);
        let mut params = HashMap::new();
        params.insert(w_id, w_tensor);
        let mut opt = Lamb::<CpuRuntime>::new(LambConfig {
            lr: 0.01,
            weight_decay: 0.0,
            max_trust_ratio: Some(10.0),
            ..Default::default()
        });
        opt.step(&client, &mut params, &grads).unwrap();
        let updated = params.get(&w_id).unwrap().to_vec::<f32>();
        assert!(
            updated[0].is_finite(),
            "update should be finite: {}",
            updated[0]
        );
        assert!(
            (updated[0] - 100.0).abs() < 1.0,
            "clamped trust ratio should limit step size: {}",
            updated[0]
        );
    }

    // A parameter with no entry in the gradient store must be left untouched.
    #[test]
    fn test_lamb_skips_missing_grads() {
        let (client, device) = cpu_setup();
        let w_tensor = Tensor::<CpuRuntime>::from_slice(&[1.0f32, 2.0], &[2], &device);
        let w_id = w_tensor.id();
        let mut params = HashMap::new();
        params.insert(w_id, w_tensor);
        let grads = GradStore::new();
        let mut opt = Lamb::<CpuRuntime>::new(LambConfig::default());
        opt.step(&client, &mut params, &grads).unwrap();
        let updated = params.get(&w_id).unwrap().to_vec::<f32>();
        assert_eq!(updated, vec![1.0, 2.0]);
    }

    // reset() must clear both the step counter and the per-parameter state map.
    #[test]
    fn test_lamb_reset() {
        let mut opt = Lamb::<CpuRuntime>::new(LambConfig::default());
        opt.reset();
        assert_eq!(opt.timestep(), 0);
        assert!(opt.state.is_empty());
    }
}