use std::io::{Read, Write};
use crate::autograd::{Variable, no_grad};
use crate::tensor::Result;
use super::checkpoint::{
write_tensor_state, read_tensor_state, write_f64_le, read_f64_le,
write_u32_le, read_u32_le, write_i64_le, read_i64_le,
};
use super::parameter::Parameter;
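/// Common interface for all optimizers: `step` applies one parameter
/// update from the gradients currently stored on the parameters,
/// `zero_grad` clears those gradients, and `set_lr` updates the
/// learning rate (including every parameter group, where present).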
pub trait Optimizer {
fn step(&mut self) -> Result<()>;
fn zero_grad(&self);
fn set_lr(&mut self, lr: f64);
/// Set the learning rate of a single parameter group. The default
/// implementation ignores the group index and falls back to `set_lr`,
/// which changes every group; optimizers with real group support
/// override it to target one group only.
fn set_group_lr(&mut self, _group: usize, lr: f64) {
self.set_lr(lr);
}
}
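/// Learning rate for one parameter group together with the contiguous
/// index range that the group occupies in the optimizer's flat
/// parameter list.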
struct GroupMeta {
lr: f64,
range: std::ops::Range<usize>,
}
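/// Binary (de)serialization of optimizer state. `save_state_file` and
/// `load_state_file` pick gzip compression when the path ends in
/// ".gz" and plain buffered I/O otherwise.
///
/// A minimal sketch of a checkpoint round trip (paths are
/// illustrative):
///
/// ```ignore
/// opt.save_state_file("run/optim.bin.gz")?;   // compressed because of the extension
/// opt2.load_state_file("run/optim.bin.gz")?;  // restores buffers, lr, and groups
/// ```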
pub trait Stateful {
fn save_state<W: Write>(&self, w: &mut W) -> Result<()>;
fn load_state<R: Read>(&mut self, r: &mut R) -> Result<()>;
fn save_state_file(&self, path: &str) -> Result<()> {
let f = std::fs::File::create(path).map_err(|e| {
crate::tensor::TensorError::new(&format!("io: {}", e))
})?;
if path.ends_with(".gz") {
let mut w = flate2::write::GzEncoder::new(f, flate2::Compression::default());
self.save_state(&mut w)?;
w.finish().map_err(|e| {
crate::tensor::TensorError::new(&format!("io: {}", e))
})?;
Ok(())
} else {
let mut w = std::io::BufWriter::new(f);
self.save_state(&mut w)
}
}
fn load_state_file(&mut self, path: &str) -> Result<()> {
let f = std::fs::File::open(path).map_err(|e| {
crate::tensor::TensorError::new(&format!("io: {}", e))
})?;
if path.ends_with(".gz") {
let mut r = flate2::read::GzDecoder::new(f);
self.load_state(&mut r)
} else {
let mut r = std::io::BufReader::new(f);
self.load_state(&mut r)
}
}
}
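/// Stochastic gradient descent with optional momentum:
///
/// v_t = momentum * v_{t-1} + g_t
/// p_t = p_{t-1} - lr * v_t
///
/// With `momentum == 0.0` the velocity buffers are skipped entirely.
/// A minimal sketch of the training loop (parameter and loss setup is
/// elided; see the tests at the bottom of this file for concrete
/// tensor construction):
///
/// ```ignore
/// let mut opt = SGD::new(&params, 0.01, 0.9);
/// opt.zero_grad();
/// loss.backward()?;
/// opt.step()?;
/// ```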
pub struct SGD {
params: Vec<Variable>,
lr: f64,
momentum: f64,
velocity: Vec<Option<crate::tensor::Tensor>>,
groups: Vec<GroupMeta>,
}
impl SGD {
pub fn new(params: &[Parameter], lr: f64, momentum: f64) -> Self {
let variables: Vec<Variable> = params.iter().map(|p| p.variable.clone()).collect();
let velocity = vec![None; variables.len()];
SGD {
params: variables,
lr,
momentum,
velocity,
groups: vec![],
}
}
pub fn with_groups(momentum: f64) -> SGDBuilder {
SGDBuilder { momentum, groups: vec![] }
}
pub fn lr(&self) -> f64 {
self.lr
}
fn lr_for_param(&self, i: usize) -> f64 {
for g in &self.groups {
if g.range.contains(&i) {
return g.lr;
}
}
self.lr
}
}
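/// Builder for an [`SGD`] with per-group learning rates. The global lr
/// defaults to the first group's lr. A sketch, assuming two
/// hypothetical parameter slices `backbone_params` and `head_params`:
///
/// ```ignore
/// let mut opt = SGD::with_groups(0.9)
///     .group(&backbone_params, 1e-3)   // group 0
///     .group(&head_params, 1e-2)       // group 1
///     .build();
/// opt.set_group_lr(1, 5e-3);           // retarget the head only
/// ```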
pub struct SGDBuilder {
momentum: f64,
groups: Vec<(Vec<Variable>, f64)>,
}
impl SGDBuilder {
pub fn group(mut self, params: &[Parameter], lr: f64) -> Self {
let vars: Vec<Variable> = params.iter().map(|p| p.variable.clone()).collect();
self.groups.push((vars, lr));
self
}
pub fn build(self) -> SGD {
let mut all_params = Vec::new();
let mut groups = Vec::new();
let base_lr = self.groups.first().map(|(_, lr)| *lr).unwrap_or(0.01);
for (vars, lr) in self.groups {
let start = all_params.len();
all_params.extend(vars);
let end = all_params.len();
groups.push(GroupMeta { lr, range: start..end });
}
let velocity = vec![None; all_params.len()];
SGD {
params: all_params,
lr: base_lr,
momentum: self.momentum,
velocity,
groups,
}
}
}
impl Optimizer for SGD {
fn step(&mut self) -> Result<()> {
no_grad(|| {
for (i, param) in self.params.iter().enumerate() {
if let Some(grad) = param.grad() {
let lr = self.lr_for_param(i);
let data = param.data().detach()?;
if self.momentum > 0.0 {
let v = match self.velocity[i].take() {
Some(v) => {
// In place: v <- momentum * v + grad.
v.mul_scalar_(self.momentum)?;
v.add_(&grad)?;
v
}
// First step: the velocity buffer starts as a copy of the
// gradient (mul_scalar(1.0) doubles as a clone).
None => grad.mul_scalar(1.0)?,
};
let scaled = v.mul_scalar(lr)?;
data.sub_(&scaled)?;
self.velocity[i] = Some(v);
} else {
let scaled = grad.mul_scalar(lr)?;
data.sub_(&scaled)?;
}
}
}
Ok(())
})
}
fn zero_grad(&self) {
for param in &self.params {
param.zero_grad_set_to_none();
}
}
fn set_lr(&mut self, lr: f64) {
self.lr = lr;
for g in &mut self.groups {
g.lr = lr;
}
}
fn set_group_lr(&mut self, group: usize, lr: f64) {
if let Some(g) = self.groups.get_mut(group) {
g.lr = lr;
}
}
}
impl Stateful for SGD {
fn save_state<W: Write>(&self, w: &mut W) -> Result<()> {
// `momentum` is a constructor hyperparameter and is not persisted;
// the checkpoint holds lr, the velocity buffers, and group metadata.
write_u32_le(w, self.params.len() as u32)?;
write_f64_le(w, self.lr)?;
for v in &self.velocity {
write_tensor_state(w, v.as_ref())?;
}
write_u32_le(w, self.groups.len() as u32)?;
for g in &self.groups {
write_f64_le(w, g.lr)?;
write_i64_le(w, g.range.start as i64)?;
write_i64_le(w, g.range.end as i64)?;
}
Ok(())
}
fn load_state<R: Read>(&mut self, r: &mut R) -> Result<()> {
let count = read_u32_le(r)? as usize;
if count != self.params.len() {
return Err(crate::tensor::TensorError::new(&format!(
"SGD: param count mismatch: checkpoint={} optimizer={}", count, self.params.len()
)));
}
self.lr = read_f64_le(r)?;
for (i, param) in self.params.iter().enumerate() {
let dev = param.data().device();
self.velocity[i] = read_tensor_state(r, dev)?;
}
let ng = read_u32_le(r)? as usize;
self.groups.clear();
for _ in 0..ng {
let lr = read_f64_le(r)?;
let start = read_i64_le(r)? as usize;
let end = read_i64_le(r)? as usize;
self.groups.push(GroupMeta { lr, range: start..end });
}
Ok(())
}
}
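/// Adam (Kingma & Ba). Per step t, with gradient g_t:
///
/// m_t = beta1 * m_{t-1} + (1 - beta1) * g_t
/// v_t = beta2 * v_{t-1} + (1 - beta2) * g_t^2
/// p_t = p_{t-1} - lr * m_hat / (sqrt(v_hat) + eps)
///
/// where m_hat = m_t / (1 - beta1^t) and v_hat = v_t / (1 - beta2^t).
/// Defaults: beta1 = 0.9, beta2 = 0.999, eps = 1e-8. Updates are
/// dispatched through the fused AdamW kernel with weight decay 0.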
pub struct Adam {
params: Vec<Variable>,
lr: f64,
beta1: f64,
beta2: f64,
eps: f64,
m: Vec<Option<crate::tensor::Tensor>>,
v: Vec<Option<crate::tensor::Tensor>>,
t: usize,
groups: Vec<GroupMeta>,
}
impl Adam {
pub fn new(params: &[Parameter], lr: f64) -> Self {
let n = params.len();
Adam {
params: params.iter().map(|p| p.variable.clone()).collect(),
lr,
beta1: 0.9,
beta2: 0.999,
eps: 1e-8,
m: vec![None; n],
v: vec![None; n],
t: 0,
groups: vec![],
}
}
pub fn with_groups() -> AdamBuilder {
AdamBuilder { beta1: 0.9, beta2: 0.999, eps: 1e-8, groups: vec![] }
}
pub fn lr(&self) -> f64 {
self.lr
}
}
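/// Builder for an [`Adam`] with per-group learning rates and
/// configurable betas/eps. A sketch, assuming two hypothetical
/// parameter slices `encoder_params` and `head_params`:
///
/// ```ignore
/// let mut opt = Adam::with_groups()
///     .betas(0.9, 0.98)
///     .group(&encoder_params, 1e-4)
///     .group(&head_params, 1e-3)
///     .build();
/// ```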
pub struct AdamBuilder {
beta1: f64,
beta2: f64,
eps: f64,
groups: Vec<(Vec<Variable>, f64)>,
}
impl AdamBuilder {
pub fn betas(mut self, beta1: f64, beta2: f64) -> Self {
self.beta1 = beta1;
self.beta2 = beta2;
self
}
pub fn eps(mut self, eps: f64) -> Self { self.eps = eps; self }
pub fn group(mut self, params: &[Parameter], lr: f64) -> Self {
let vars: Vec<Variable> = params.iter().map(|p| p.variable.clone()).collect();
self.groups.push((vars, lr));
self
}
pub fn build(self) -> Adam {
let mut all_params = Vec::new();
let mut groups = Vec::new();
let base_lr = self.groups.first().map(|(_, lr)| *lr).unwrap_or(1e-3);
for (vars, lr) in self.groups {
let start = all_params.len();
all_params.extend(vars);
let end = all_params.len();
groups.push(GroupMeta { lr, range: start..end });
}
let n = all_params.len();
Adam {
params: all_params,
lr: base_lr,
beta1: self.beta1,
beta2: self.beta2,
eps: self.eps,
m: vec![None; n],
v: vec![None; n],
t: 0,
groups,
}
}
}
impl Optimizer for Adam {
fn step(&mut self) -> Result<()> {
// Plain Adam is the fused AdamW update with zero weight decay.
self.adam_update(0.0)
}
fn zero_grad(&self) {
for param in &self.params {
param.zero_grad_set_to_none();
}
}
fn set_lr(&mut self, lr: f64) {
self.lr = lr;
for g in &mut self.groups {
g.lr = lr;
}
}
fn set_group_lr(&mut self, group: usize, lr: f64) {
if let Some(g) = self.groups.get_mut(group) {
g.lr = lr;
}
}
}
impl Adam {
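/// Shared Adam/AdamW update. Parameters are processed group by group
/// (one implicit group at `self.lr` when no groups are configured),
/// first/second-moment buffers are allocated lazily as zeros, and the
/// whole group is handed to the fused kernel in one batch.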
fn adam_update(&mut self, weight_decay: f64) -> Result<()> {
// Step counter for bias correction; advances once per call even if
// no parameter currently holds a gradient.
self.t += 1;
no_grad(|| {
let effective_groups: Vec<(f64, std::ops::Range<usize>)> = if self.groups.is_empty() {
vec![(self.lr, 0..self.params.len())]
} else {
self.groups.iter().map(|g| (g.lr, g.range.clone())).collect()
};
for (lr, range) in &effective_groups {
let mut p_tensors = Vec::new();
let mut g_tensors = Vec::new();
let mut m_tensors = Vec::new();
let mut v_tensors = Vec::new();
for i in range.clone() {
if let Some(grad) = self.params[i].grad() {
if self.m[i].is_none() {
self.m[i] = Some(crate::tensor::Tensor::zeros_like(&grad)?);
}
if self.v[i].is_none() {
self.v[i] = Some(crate::tensor::Tensor::zeros_like(&grad)?);
}
p_tensors.push(self.params[i].data());
g_tensors.push(grad);
m_tensors.push(self.m[i].as_ref().unwrap().clone());
v_tensors.push(self.v[i].as_ref().unwrap().clone());
}
}
if !p_tensors.is_empty() {
// One fused kernel call updates the batch's params, m, and v in
// place; weight_decay = 0.0 recovers plain Adam.
crate::tensor::Tensor::fused_adamw_(
&p_tensors, &g_tensors, &m_tensors, &v_tensors,
*lr, self.beta1, self.beta2, self.eps,
weight_decay, self.t as i64, None, None,
)?;
}
}
Ok(())
})
}
}
impl Stateful for Adam {
fn save_state<W: Write>(&self, w: &mut W) -> Result<()> {
write_u32_le(w, self.params.len() as u32)?;
write_f64_le(w, self.lr)?;
write_i64_le(w, self.t as i64)?;
for i in 0..self.params.len() {
write_tensor_state(w, self.m[i].as_ref())?;
write_tensor_state(w, self.v[i].as_ref())?;
}
write_u32_le(w, self.groups.len() as u32)?;
for g in &self.groups {
write_f64_le(w, g.lr)?;
write_i64_le(w, g.range.start as i64)?;
write_i64_le(w, g.range.end as i64)?;
}
Ok(())
}
fn load_state<R: Read>(&mut self, r: &mut R) -> Result<()> {
let count = read_u32_le(r)? as usize;
if count != self.params.len() {
return Err(crate::tensor::TensorError::new(&format!(
"Adam: param count mismatch: checkpoint={} optimizer={}", count, self.params.len()
)));
}
self.lr = read_f64_le(r)?;
self.t = read_i64_le(r)? as usize;
for i in 0..self.params.len() {
let dev = self.params[i].data().device();
self.m[i] = read_tensor_state(r, dev)?;
self.v[i] = read_tensor_state(r, dev)?;
}
let ng = read_u32_le(r)? as usize;
self.groups.clear();
for _ in 0..ng {
let lr = read_f64_le(r)?;
let start = read_i64_le(r)? as usize;
let end = read_i64_le(r)? as usize;
self.groups.push(GroupMeta { lr, range: start..end });
}
Ok(())
}
}
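/// AdamW: Adam with decoupled weight decay (Loshchilov & Hutter). The
/// decay is passed straight to the fused AdamW kernel rather than
/// folded into the gradient, so it is not scaled by the adaptive
/// second-moment normalization.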
pub struct AdamW {
adam: Adam,
weight_decay: f64,
}
impl AdamW {
pub fn new(params: &[Parameter], lr: f64, weight_decay: f64) -> Self {
AdamW {
adam: Adam::new(params, lr),
weight_decay,
}
}
pub fn with_groups(weight_decay: f64) -> AdamWBuilder {
AdamWBuilder { beta1: 0.9, beta2: 0.999, eps: 1e-8, weight_decay, groups: vec![] }
}
pub fn lr(&self) -> f64 {
self.adam.lr
}
}
pub struct AdamWBuilder {
beta1: f64,
beta2: f64,
eps: f64,
weight_decay: f64,
groups: Vec<(Vec<Variable>, f64)>,
}
impl AdamWBuilder {
pub fn betas(mut self, beta1: f64, beta2: f64) -> Self {
self.beta1 = beta1;
self.beta2 = beta2;
self
}
pub fn eps(mut self, eps: f64) -> Self { self.eps = eps; self }
pub fn group(mut self, params: &[Parameter], lr: f64) -> Self {
let vars: Vec<Variable> = params.iter().map(|p| p.variable.clone()).collect();
self.groups.push((vars, lr));
self
}
pub fn build(self) -> AdamW {
let mut all_params = Vec::new();
let mut groups = Vec::new();
let base_lr = self.groups.first().map(|(_, lr)| *lr).unwrap_or(1e-3);
for (vars, lr) in self.groups {
let start = all_params.len();
all_params.extend(vars);
let end = all_params.len();
groups.push(GroupMeta { lr, range: start..end });
}
let n = all_params.len();
AdamW {
adam: Adam {
params: all_params,
lr: base_lr,
beta1: self.beta1,
beta2: self.beta2,
eps: self.eps,
m: vec![None; n],
v: vec![None; n],
t: 0,
groups,
},
weight_decay: self.weight_decay,
}
}
}
impl Optimizer for AdamW {
fn step(&mut self) -> Result<()> {
self.adam.adam_update(self.weight_decay)
}
fn zero_grad(&self) {
self.adam.zero_grad()
}
fn set_lr(&mut self, lr: f64) {
self.adam.set_lr(lr);
}
fn set_group_lr(&mut self, group: usize, lr: f64) {
self.adam.set_group_lr(group, lr);
}
}
impl Stateful for AdamW {
fn save_state<W: Write>(&self, w: &mut W) -> Result<()> {
write_f64_le(w, self.weight_decay)?;
self.adam.save_state(w)
}
fn load_state<R: Read>(&mut self, r: &mut R) -> Result<()> {
self.weight_decay = read_f64_le(r)?;
self.adam.load_state(r)
}
}
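/// RMSprop:
///
/// v_t = alpha * v_{t-1} + (1 - alpha) * g_t^2
/// p_t = p_{t-1} - lr * g_t / (sqrt(v_t) + eps)
///
/// with optional L2 weight decay (added to the gradient before the
/// update) and optional momentum on the normalized update. Defaults:
/// alpha = 0.99, eps = 1e-8, weight_decay = 0.0, momentum = 0.0.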
pub struct RMSprop {
params: Vec<Variable>,
lr: f64,
alpha: f64,
eps: f64,
weight_decay: f64,
momentum: f64,
v: Vec<Option<crate::tensor::Tensor>>,
buf: Vec<Option<crate::tensor::Tensor>>,
groups: Vec<GroupMeta>,
}
impl RMSprop {
pub fn new(params: &[Parameter], lr: f64) -> Self {
let n = params.len();
RMSprop {
params: params.iter().map(|p| p.variable.clone()).collect(),
lr,
alpha: 0.99,
eps: 1e-8,
weight_decay: 0.0,
momentum: 0.0,
v: vec![None; n],
buf: vec![None; n],
groups: vec![],
}
}
pub fn builder(params: &[Parameter], lr: f64) -> RMSpropBuilder {
RMSpropBuilder {
params: params.to_vec(),
lr,
alpha: 0.99,
eps: 1e-8,
weight_decay: 0.0,
momentum: 0.0,
}
}
pub fn lr(&self) -> f64 {
self.lr
}
fn lr_for_param(&self, i: usize) -> f64 {
for g in &self.groups {
if g.range.contains(&i) {
return g.lr;
}
}
self.lr
}
}
pub struct RMSpropBuilder {
params: Vec<Parameter>,
lr: f64,
alpha: f64,
eps: f64,
weight_decay: f64,
momentum: f64,
}
impl RMSpropBuilder {
pub fn alpha(mut self, alpha: f64) -> Self { self.alpha = alpha; self }
pub fn eps(mut self, eps: f64) -> Self { self.eps = eps; self }
pub fn weight_decay(mut self, wd: f64) -> Self { self.weight_decay = wd; self }
pub fn momentum(mut self, momentum: f64) -> Self { self.momentum = momentum; self }
pub fn build(self) -> RMSprop {
let n = self.params.len();
RMSprop {
params: self.params.iter().map(|p| p.variable.clone()).collect(),
lr: self.lr,
alpha: self.alpha,
eps: self.eps,
weight_decay: self.weight_decay,
momentum: self.momentum,
v: vec![None; n],
buf: vec![None; n],
groups: vec![],
}
}
}
impl Optimizer for RMSprop {
fn step(&mut self) -> Result<()> {
no_grad(|| {
for (i, param) in self.params.iter().enumerate() {
if let Some(mut grad) = param.grad() {
let lr = self.lr_for_param(i);
let data = param.data().detach()?;
if self.weight_decay > 0.0 {
grad = grad.add(&data.mul_scalar(self.weight_decay)?)?;
}
let grad_sq = grad.mul(&grad)?;
let v = match self.v[i].take() {
Some(v) => {
v.mul_scalar_(self.alpha)?;
let scaled = grad_sq.mul_scalar(1.0 - self.alpha)?;
v.add_(&scaled)?;
v
}
None => grad_sq.mul_scalar(1.0 - self.alpha)?,
};
let denom = v.sqrt()?.add_scalar(self.eps)?;
let update = grad.div(&denom)?;
if self.momentum > 0.0 {
let b = match self.buf[i].take() {
Some(b) => {
// In place: b <- momentum * b + update.
b.mul_scalar_(self.momentum)?;
b.add_(&update)?;
b
}
// First step: the buffer starts as a copy of the update
// (mul_scalar(1.0) doubles as a clone).
None => update.mul_scalar(1.0)?,
};
let scaled = b.mul_scalar(lr)?;
data.sub_(&scaled)?;
self.buf[i] = Some(b);
} else {
let scaled = update.mul_scalar(lr)?;
data.sub_(&scaled)?;
}
self.v[i] = Some(v);
}
}
Ok(())
})
}
fn zero_grad(&self) {
for param in &self.params {
param.zero_grad_set_to_none();
}
}
fn set_lr(&mut self, lr: f64) {
self.lr = lr;
for g in &mut self.groups {
g.lr = lr;
}
}
fn set_group_lr(&mut self, group: usize, lr: f64) {
if let Some(g) = self.groups.get_mut(group) {
g.lr = lr;
}
}
}
impl Stateful for RMSprop {
fn save_state<W: Write>(&self, w: &mut W) -> Result<()> {
write_u32_le(w, self.params.len() as u32)?;
write_f64_le(w, self.lr)?;
write_f64_le(w, self.alpha)?;
write_f64_le(w, self.eps)?;
write_f64_le(w, self.weight_decay)?;
write_f64_le(w, self.momentum)?;
for i in 0..self.params.len() {
write_tensor_state(w, self.v[i].as_ref())?;
write_tensor_state(w, self.buf[i].as_ref())?;
}
write_u32_le(w, self.groups.len() as u32)?;
for g in &self.groups {
write_f64_le(w, g.lr)?;
write_i64_le(w, g.range.start as i64)?;
write_i64_le(w, g.range.end as i64)?;
}
Ok(())
}
fn load_state<R: Read>(&mut self, r: &mut R) -> Result<()> {
let count = read_u32_le(r)? as usize;
if count != self.params.len() {
return Err(crate::tensor::TensorError::new(&format!(
"RMSprop: param count mismatch: checkpoint={} optimizer={}", count, self.params.len()
)));
}
self.lr = read_f64_le(r)?;
self.alpha = read_f64_le(r)?;
self.eps = read_f64_le(r)?;
self.weight_decay = read_f64_le(r)?;
self.momentum = read_f64_le(r)?;
for i in 0..self.params.len() {
let dev = self.params[i].data().device();
self.v[i] = read_tensor_state(r, dev)?;
self.buf[i] = read_tensor_state(r, dev)?;
}
let ng = read_u32_le(r)? as usize;
self.groups.clear();
for _ in 0..ng {
let lr = read_f64_le(r)?;
let start = read_i64_le(r)? as usize;
let end = read_i64_le(r)? as usize;
self.groups.push(GroupMeta { lr, range: start..end });
}
Ok(())
}
}
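/// Adagrad: accumulates squared gradients and scales each update by
/// the inverse root of the running sum,
///
/// G_t = G_{t-1} + g_t^2
/// p_t = p_{t-1} - clr * g_t / (sqrt(G_t) + eps)
///
/// where clr = lr / (1 + (t - 1) * lr_decay). Defaults: eps = 1e-10,
/// weight_decay = 0.0, lr_decay = 0.0.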
pub struct Adagrad {
params: Vec<Variable>,
lr: f64,
eps: f64,
weight_decay: f64,
lr_decay: f64,
state_sum: Vec<Option<crate::tensor::Tensor>>,
step_count: u64,
}
pub struct AdagradBuilder {
params: Vec<Parameter>,
lr: f64,
eps: f64,
weight_decay: f64,
lr_decay: f64,
}
impl AdagradBuilder {
pub fn eps(mut self, eps: f64) -> Self { self.eps = eps; self }
pub fn weight_decay(mut self, wd: f64) -> Self { self.weight_decay = wd; self }
pub fn lr_decay(mut self, lr_decay: f64) -> Self { self.lr_decay = lr_decay; self }
pub fn build(self) -> Adagrad {
let n = self.params.len();
Adagrad {
params: self.params.iter().map(|p| p.variable.clone()).collect(),
lr: self.lr, eps: self.eps,
weight_decay: self.weight_decay, lr_decay: self.lr_decay,
state_sum: vec![None; n],
step_count: 0,
}
}
}
impl Adagrad {
pub fn new(params: &[Parameter], lr: f64) -> Self {
let n = params.len();
Adagrad {
params: params.iter().map(|p| p.variable.clone()).collect(),
lr, eps: 1e-10, weight_decay: 0.0, lr_decay: 0.0,
state_sum: vec![None; n],
step_count: 0,
}
}
pub fn builder(params: &[Parameter], lr: f64) -> AdagradBuilder {
AdagradBuilder {
params: params.to_vec(), lr, eps: 1e-10, weight_decay: 0.0, lr_decay: 0.0,
}
}
pub fn lr(&self) -> f64 { self.lr }
}
impl Optimizer for Adagrad {
fn step(&mut self) -> Result<()> {
self.step_count += 1;
// Decayed learning rate: lr / (1 + (t - 1) * lr_decay).
let clr = self.lr / (1.0 + (self.step_count - 1) as f64 * self.lr_decay);
no_grad(|| {
for (i, param) in self.params.iter().enumerate() {
if let Some(mut grad) = param.grad() {
let data = param.data().detach()?;
if self.weight_decay > 0.0 {
grad = grad.add(&data.mul_scalar(self.weight_decay)?)?;
}
let grad2 = grad.mul(&grad)?;
let ss = match self.state_sum[i].take() {
Some(ss) => ss.add(&grad2)?,
None => grad2,
};
let update = grad.div(&ss.sqrt()?.add_scalar(self.eps)?)?.mul_scalar(clr)?;
data.sub_(&update)?;
self.state_sum[i] = Some(ss);
}
}
Ok(())
})
}
fn zero_grad(&self) {
for p in &self.params { p.zero_grad_set_to_none(); }
}
fn set_lr(&mut self, lr: f64) { self.lr = lr; }
}
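/// RAdam (Liu et al.): Adam with a rectification term that corrects
/// the variance of the adaptive learning rate in early steps. While
/// the approximated SMA length rho_t is below the tractability
/// threshold, the update falls back to bias-corrected momentum only;
/// afterwards the rectified adaptive step is used.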
pub struct RAdam {
params: Vec<Variable>,
lr: f64,
beta1: f64,
beta2: f64,
eps: f64,
weight_decay: f64,
m: Vec<Option<crate::tensor::Tensor>>,
v: Vec<Option<crate::tensor::Tensor>>,
step_count: u64,
}
impl RAdam {
pub fn new(params: &[Parameter], lr: f64) -> Self {
let n = params.len();
RAdam {
params: params.iter().map(|p| p.variable.clone()).collect(),
lr, beta1: 0.9, beta2: 0.999, eps: 1e-8, weight_decay: 0.0,
m: vec![None; n], v: vec![None; n], step_count: 0,
}
}
pub fn lr(&self) -> f64 { self.lr }
}
impl Optimizer for RAdam {
fn step(&mut self) -> Result<()> {
self.step_count += 1;
let t = self.step_count as f64;
let b1 = self.beta1;
let b2 = self.beta2;
let b1t = b1.powf(t);
let b2t = b2.powf(t);
// rho_inf is the maximum SMA length; rho_t approximates it at step t.
let rho_inf = 2.0 / (1.0 - b2) - 1.0;
let rho_t = rho_inf - 2.0 * t * b2t / (1.0 - b2t);
no_grad(|| {
for (i, param) in self.params.iter().enumerate() {
if let Some(mut grad) = param.grad() {
let data = param.data().detach()?;
if self.weight_decay > 0.0 {
grad = grad.add(&data.mul_scalar(self.weight_decay)?)?;
}
let m_new = match self.m[i].take() {
Some(m) => m.mul_scalar(b1)?.add(&grad.mul_scalar(1.0 - b1)?)?,
None => grad.mul_scalar(1.0 - b1)?,
};
let grad2 = grad.mul(&grad)?;
let v_new = match self.v[i].take() {
Some(v) => v.mul_scalar(b2)?.add(&grad2.mul_scalar(1.0 - b2)?)?,
None => grad2.mul_scalar(1.0 - b2)?,
};
let m_hat = m_new.mul_scalar(1.0 / (1.0 - b1t))?;
if rho_t > 5.0 {
// Variance is tractable: apply the rectified adaptive update.
let v_hat = v_new.mul_scalar(1.0 / (1.0 - b2t))?;
let rect = ((rho_t - 4.0) * (rho_t - 2.0) * rho_inf /
((rho_inf - 4.0) * (rho_inf - 2.0) * rho_t)).sqrt();
let update = m_hat.div(&v_hat.sqrt()?.add_scalar(self.eps)?)?.mul_scalar(self.lr * rect)?;
data.sub_(&update)?;
} else {
// Early steps: plain bias-corrected momentum update.
let update = m_hat.mul_scalar(self.lr)?;
data.sub_(&update)?;
}
self.m[i] = Some(m_new);
self.v[i] = Some(v_new);
}
}
Ok(())
})
}
fn zero_grad(&self) {
for p in &self.params { p.zero_grad_set_to_none(); }
}
fn set_lr(&mut self, lr: f64) { self.lr = lr; }
}
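/// NAdam: Adam with Nesterov momentum. The bias-corrected first
/// moment looks one step ahead:
///
/// m_hat = beta1 * m_t / (1 - beta1^(t+1)) + (1 - beta1) * g_t / (1 - beta1^t)
/// p_t = p_{t-1} - lr * m_hat / (sqrt(v_hat) + eps)
///
/// with v_hat = v_t / (1 - beta2^t) as in Adam.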
pub struct NAdam {
params: Vec<Variable>,
lr: f64,
beta1: f64,
beta2: f64,
eps: f64,
weight_decay: f64,
m: Vec<Option<crate::tensor::Tensor>>,
v: Vec<Option<crate::tensor::Tensor>>,
step_count: u64,
}
impl NAdam {
pub fn new(params: &[Parameter], lr: f64) -> Self {
let n = params.len();
NAdam {
params: params.iter().map(|p| p.variable.clone()).collect(),
lr, beta1: 0.9, beta2: 0.999, eps: 1e-8, weight_decay: 0.0,
m: vec![None; n], v: vec![None; n], step_count: 0,
}
}
pub fn lr(&self) -> f64 { self.lr }
}
impl Optimizer for NAdam {
fn step(&mut self) -> Result<()> {
self.step_count += 1;
let t = self.step_count as f64;
let b1 = self.beta1;
let b2 = self.beta2;
let b1t = b1.powf(t);
let b2t = b2.powf(t);
let b1t1 = b1.powf(t + 1.0);
no_grad(|| {
for (i, param) in self.params.iter().enumerate() {
if let Some(mut grad) = param.grad() {
let data = param.data().detach()?;
if self.weight_decay > 0.0 {
grad = grad.add(&data.mul_scalar(self.weight_decay)?)?;
}
let m_new = match self.m[i].take() {
Some(m) => m.mul_scalar(b1)?.add(&grad.mul_scalar(1.0 - b1)?)?,
None => grad.mul_scalar(1.0 - b1)?,
};
let grad2 = grad.mul(&grad)?;
let v_new = match self.v[i].take() {
Some(v) => v.mul_scalar(b2)?.add(&grad2.mul_scalar(1.0 - b2)?)?,
None => grad2.mul_scalar(1.0 - b2)?,
};
// Nesterov look-ahead: mix the bias-corrected momentum with the
// bias-corrected current gradient.
let m_hat = m_new.mul_scalar(b1 / (1.0 - b1t1))?
.add(&grad.mul_scalar((1.0 - b1) / (1.0 - b1t))?)?;
let v_hat = v_new.mul_scalar(1.0 / (1.0 - b2t))?;
let update = m_hat.div(&v_hat.sqrt()?.add_scalar(self.eps)?)?.mul_scalar(self.lr)?;
data.sub_(&update)?;
self.m[i] = Some(m_new);
self.v[i] = Some(v_new);
}
}
Ok(())
})
}
fn zero_grad(&self) {
for p in &self.params { p.zero_grad_set_to_none(); }
}
fn set_lr(&mut self, lr: f64) { self.lr = lr; }
}
#[cfg(test)]
mod tests {
use super::*;
use crate::tensor::{Tensor, TensorOptions};
fn make_param(name: &str, shape: &[i64]) -> Parameter {
let t = Tensor::randn(shape, TensorOptions {
dtype: crate::tensor::DType::Float32,
device: crate::tensor::test_device(),
}).unwrap();
Parameter::new(t, name)
}
#[test]
fn test_adam_backward_compat() {
let p = make_param("w", &[3, 2]);
let mut opt = Adam::new(std::slice::from_ref(&p), 0.01);
let x = Variable::new(
Tensor::from_f32(&[1.0, 2.0, 3.0], &[1, 3], crate::tensor::test_device()).unwrap(),
false,
);
let y = x.matmul(&p.variable).unwrap();
let loss = y.sum().unwrap();
loss.backward().unwrap();
let before = p.variable.data().to_f32_vec().unwrap();
opt.step().unwrap();
let after = p.variable.data().to_f32_vec().unwrap();
assert_ne!(before, after, "params should change after step");
}
#[test]
fn test_adam_two_groups_different_lr() {
let p1 = make_param("w1", &[3, 2]);
let p2 = make_param("w2", &[3, 2]);
let mut opt = Adam::with_groups()
.group(std::slice::from_ref(&p1), 0.1)
.group(std::slice::from_ref(&p2), 1e-10)
.build();
let x = Variable::new(
Tensor::from_f32(&[1.0, 2.0, 3.0], &[1, 3], crate::tensor::test_device()).unwrap(),
false,
);
let y1 = x.matmul(&p1.variable).unwrap();
let y2 = x.matmul(&p2.variable).unwrap();
let loss = y1.add(&y2).unwrap().sum().unwrap();
loss.backward().unwrap();
let p1_before = p1.variable.data().to_f32_vec().unwrap();
let p2_before = p2.variable.data().to_f32_vec().unwrap();
opt.step().unwrap();
let p1_after = p1.variable.data().to_f32_vec().unwrap();
let p2_after = p2.variable.data().to_f32_vec().unwrap();
let p1_delta: f64 = p1_before.iter().zip(&p1_after)
.map(|(a, b)| (a - b).abs() as f64).sum();
let p2_delta: f64 = p2_before.iter().zip(&p2_after)
.map(|(a, b)| (a - b).abs() as f64).sum();
assert!(p1_delta > p2_delta * 1e6,
"high-LR group should move much more: p1_delta={}, p2_delta={}", p1_delta, p2_delta);
}
#[test]
fn test_set_group_lr_changes_one_group() {
let p1 = make_param("w1", &[3, 2]);
let p2 = make_param("w2", &[3, 2]);
let mut opt = Adam::with_groups()
.group(std::slice::from_ref(&p1), 0.01)
.group(std::slice::from_ref(&p2), 0.01)
.build();
opt.set_group_lr(1, 0.99);
assert!((opt.groups[0].lr - 0.01).abs() < 1e-12);
assert!((opt.groups[1].lr - 0.99).abs() < 1e-12);
}
#[test]
fn test_set_lr_changes_all_groups() {
let p1 = make_param("w1", &[3, 2]);
let p2 = make_param("w2", &[3, 2]);
let mut opt = Adam::with_groups()
.group(std::slice::from_ref(&p1), 0.01)
.group(std::slice::from_ref(&p2), 0.05)
.build();
opt.set_lr(0.42);
assert!((opt.lr - 0.42).abs() < 1e-12);
assert!((opt.groups[0].lr - 0.42).abs() < 1e-12);
assert!((opt.groups[1].lr - 0.42).abs() < 1e-12);
}
#[test]
fn test_frozen_params_in_group_no_crash() {
let p1 = make_param("w1", &[3, 2]);
let p2 = make_param("w2", &[3, 2]);
p1.freeze().unwrap();
let mut opt = Adam::with_groups()
.group(&[p1, p2.clone()], 0.01)
.build();
let x = Variable::new(
Tensor::from_f32(&[1.0, 2.0, 3.0], &[1, 3], crate::tensor::test_device()).unwrap(),
false,
);
let y = x.matmul(&p2.variable).unwrap();
let loss = y.sum().unwrap();
loss.backward().unwrap();
opt.step().unwrap();
opt.zero_grad();
}
#[test]
fn test_adam_save_load_with_groups() {
let p1 = make_param("w1", &[3, 2]);
let p2 = make_param("w2", &[3, 2]);
let mut opt = Adam::with_groups()
.group(std::slice::from_ref(&p1), 0.01)
.group(std::slice::from_ref(&p2), 0.05)
.build();
let x = Variable::new(
Tensor::from_f32(&[1.0, 2.0, 3.0], &[1, 3], crate::tensor::test_device()).unwrap(),
false,
);
let y1 = x.matmul(&p1.variable).unwrap();
let y2 = x.matmul(&p2.variable).unwrap();
let loss = y1.add(&y2).unwrap().sum().unwrap();
loss.backward().unwrap();
opt.step().unwrap();
let mut buf = Vec::new();
opt.save_state(&mut buf).unwrap();
let mut opt2 = Adam::with_groups()
.group(std::slice::from_ref(&p1), 0.99)
.group(std::slice::from_ref(&p2), 0.99)
.build();
let mut cursor = std::io::Cursor::new(&buf);
opt2.load_state(&mut cursor).unwrap();
assert_eq!(opt2.t, opt.t);
assert!((opt2.groups[0].lr - 0.01).abs() < 1e-12);
assert!((opt2.groups[1].lr - 0.05).abs() < 1e-12);
}
#[test]
fn test_fused_adam_numerical_correctness() {
let param = Tensor::from_f32(&[1.0, 2.0, 3.0, 4.0], &[4], crate::tensor::test_device()).unwrap();
let grad = Tensor::from_f32(&[0.1, 0.2, 0.3, 0.4], &[4], crate::tensor::test_device()).unwrap();
let m = Tensor::zeros(&[4], crate::tensor::test_opts()).unwrap();
let v = Tensor::zeros(&[4], crate::tensor::test_opts()).unwrap();
let lr = 0.001;
let beta1 = 0.9;
let beta2 = 0.999;
let eps = 1e-8;
let step: i64 = 1;
param.adam_step(&grad, &m, &v, lr, beta1, beta2, eps, 0.0, step).unwrap();
let p_data = param.to_f32_vec().unwrap();
let m_data = m.to_f32_vec().unwrap();
let v_data = v.to_f32_vec().unwrap();
for (i, &g) in [0.1f32, 0.2, 0.3, 0.4].iter().enumerate() {
assert!((m_data[i] - 0.1 * g).abs() < 1e-6,
"m[{}]: got {}, expected {}", i, m_data[i], 0.1 * g);
}
for (i, &g) in [0.1f32, 0.2, 0.3, 0.4].iter().enumerate() {
assert!((v_data[i] - 0.001 * g * g).abs() < 1e-9,
"v[{}]: got {}, expected {}", i, v_data[i], 0.001 * g * g);
}
let orig = [1.0f32, 2.0, 3.0, 4.0];
for (i, &o) in orig.iter().enumerate() {
assert!((p_data[i] - (o - lr as f32)).abs() < 1e-5,
"p[{}]: got {}, expected ~{}", i, p_data[i], o - lr as f32);
}
}
#[test]
fn test_fused_adamw_weight_decay() {
let param = Tensor::from_f32(&[1.0, 2.0], &[2], crate::tensor::test_device()).unwrap();
let grad = Tensor::from_f32(&[0.1, 0.1], &[2], crate::tensor::test_device()).unwrap();
let m = Tensor::zeros(&[2], crate::tensor::test_opts()).unwrap();
let v = Tensor::zeros(&[2], crate::tensor::test_opts()).unwrap();
let lr = 0.001;
let wd = 0.01;
param.adam_step(&grad, &m, &v, lr, 0.9, 0.999, 1e-8, wd, 1).unwrap();
let p_data = param.to_f32_vec().unwrap();
assert!(p_data[0] < 1.0, "p[0] should decrease: got {}", p_data[0]);
assert!(p_data[1] < 2.0, "p[1] should decrease: got {}", p_data[1]);
let decay_0 = 1.0 - p_data[0] as f64;
let decay_1 = 2.0 - p_data[1] as f64;
assert!(decay_1 > decay_0, "larger param should decay more: d0={}, d1={}", decay_0, decay_1);
}
#[test]
fn test_fused_adam_multi_step_convergence() {
let param = Tensor::from_f32(&[5.0], &[1], crate::tensor::test_device()).unwrap();
let grad = Tensor::from_f32(&[1.0], &[1], crate::tensor::test_device()).unwrap();
let m = Tensor::zeros(&[1], crate::tensor::test_opts()).unwrap();
let v = Tensor::zeros(&[1], crate::tensor::test_opts()).unwrap();
for step in 1..=10 {
param.adam_step(&grad, &m, &v, 0.01, 0.9, 0.999, 1e-8, 0.0, step).unwrap();
}
let m_data = m.to_f32_vec().unwrap();
let p_data = param.to_f32_vec().unwrap();
assert!((m_data[0] - 0.6513).abs() < 0.01,
"m after 10 steps: got {}", m_data[0]);
assert!(v.to_f32_vec().unwrap()[0] > 0.0, "v should accumulate");
assert!(p_data[0] < 5.0, "param should decrease: got {}", p_data[0]);
}
#[test]
fn test_rmsprop_basic() {
let p = make_param("w", &[3, 2]);
let mut opt = RMSprop::new(std::slice::from_ref(&p), 0.01);
let x = Variable::new(
Tensor::from_f32(&[1.0, 2.0, 3.0], &[1, 3], crate::tensor::test_device()).unwrap(),
false,
);
let y = x.matmul(&p.variable).unwrap();
let loss = y.sum().unwrap();
loss.backward().unwrap();
let before = p.variable.data().to_f32_vec().unwrap();
opt.step().unwrap();
let after = p.variable.data().to_f32_vec().unwrap();
assert_ne!(before, after, "params should change after step");
}
#[test]
fn test_rmsprop_with_momentum() {
let p = make_param("w", &[3, 2]);
let mut opt = RMSprop::builder(std::slice::from_ref(&p), 0.01)
.momentum(0.9)
.build();
let x = Variable::new(
Tensor::from_f32(&[1.0, 2.0, 3.0], &[1, 3], crate::tensor::test_device()).unwrap(),
false,
);
for _ in 0..2 {
opt.zero_grad();
let y = x.matmul(&p.variable).unwrap();
let loss = y.sum().unwrap();
loss.backward().unwrap();
opt.step().unwrap();
}
let data = p.variable.data().to_f32_vec().unwrap();
assert!(data.iter().any(|&v| v.abs() > 0.0), "params should be non-zero");
}
#[test]
fn test_rmsprop_with_weight_decay() {
let init = [0.5f32, -0.3, 0.1, 0.8, -0.2, 0.4];
let dev = crate::tensor::test_device();
let p1 = Parameter::new(
Tensor::from_f32(&init, &[3, 2], dev).unwrap(), "w1");
let p2 = Parameter::new(
Tensor::from_f32(&init, &[3, 2], dev).unwrap(), "w2");
let mut opt_wd = RMSprop::builder(std::slice::from_ref(&p1), 0.01)
.weight_decay(0.1)
.build();
let mut opt_plain = RMSprop::new(std::slice::from_ref(&p2), 0.01);
let x = Variable::new(
Tensor::from_f32(&[1.0, 2.0, 3.0], &[1, 3], dev).unwrap(),
false,
);
for _ in 0..10 {
opt_wd.zero_grad();
let y1 = x.matmul(&p1.variable).unwrap();
y1.sum().unwrap().backward().unwrap();
opt_wd.step().unwrap();
opt_plain.zero_grad();
let y2 = x.matmul(&p2.variable).unwrap();
y2.sum().unwrap().backward().unwrap();
opt_plain.step().unwrap();
}
let d1 = p1.variable.data().to_f32_vec().unwrap();
let d2 = p2.variable.data().to_f32_vec().unwrap();
assert_ne!(d1, d2, "weight decay should produce different results after 10 steps");
}
#[test]
fn test_rmsprop_convergence() {
let p = Parameter::new(
Tensor::from_f32(&[5.0], &[1], crate::tensor::test_device()).unwrap(),
"x",
);
let mut opt = RMSprop::new(std::slice::from_ref(&p), 0.1);
let x = Variable::new(
Tensor::from_f32(&[1.0], &[1], crate::tensor::test_device()).unwrap(),
false,
);
let y = x.mul(&p.variable).unwrap();
let loss = y.mul(&y).unwrap().sum().unwrap();
loss.backward().unwrap();
opt.step().unwrap();
let val = p.variable.data().to_f32_vec().unwrap()[0];
assert!(val < 5.0, "param should decrease from 5.0, got {}", val);
}
#[test]
fn test_rmsprop_save_load() {
let p = make_param("w", &[3, 2]);
let mut opt = RMSprop::builder(std::slice::from_ref(&p), 0.01)
.momentum(0.9)
.alpha(0.95)
.weight_decay(0.01)
.build();
let x = Variable::new(
Tensor::from_f32(&[1.0, 2.0, 3.0], &[1, 3], crate::tensor::test_device()).unwrap(),
false,
);
let y = x.matmul(&p.variable).unwrap();
y.sum().unwrap().backward().unwrap();
opt.step().unwrap();
let mut buf = Vec::new();
opt.save_state(&mut buf).unwrap();
let mut opt2 = RMSprop::builder(std::slice::from_ref(&p), 0.99)
.build();
let mut cursor = std::io::Cursor::new(&buf);
opt2.load_state(&mut cursor).unwrap();
assert!((opt2.lr - 0.01).abs() < 1e-12);
assert!((opt2.alpha - 0.95).abs() < 1e-12);
assert!((opt2.momentum - 0.9).abs() < 1e-12);
assert!((opt2.weight_decay - 0.01).abs() < 1e-12);
}
#[test]
fn test_rmsprop_builder_defaults() {
let p = make_param("w", &[2]);
let opt = RMSprop::new(std::slice::from_ref(&p), 0.01);
assert!((opt.alpha - 0.99).abs() < 1e-12);
assert!((opt.eps - 1e-8).abs() < 1e-15);
assert!((opt.weight_decay).abs() < 1e-12);
assert!((opt.momentum).abs() < 1e-12);
}
#[test]
fn test_rmsprop_frozen_params() {
let p1 = make_param("w1", &[3, 2]);
let p2 = make_param("w2", &[3, 2]);
p1.freeze().unwrap();
let mut opt = RMSprop::new(&[p1, p2.clone()], 0.01);
let x = Variable::new(
Tensor::from_f32(&[1.0, 2.0, 3.0], &[1, 3], crate::tensor::test_device()).unwrap(),
false,
);
let y = x.matmul(&p2.variable).unwrap();
y.sum().unwrap().backward().unwrap();
opt.step().unwrap();
opt.zero_grad();
}
#[test]
fn test_adagrad_steps() {
let p = make_param("w", &[1]);
let before = p.variable.data().item().unwrap();
let mut opt = Adagrad::new(std::slice::from_ref(&p), 0.5);
let x = Variable::new(
Tensor::from_f32(&[2.0], &[1], crate::tensor::test_device()).unwrap(), false,
);
let loss = x.mul(&p.variable).unwrap().sum().unwrap();
loss.backward().unwrap();
opt.step().unwrap();
let after = p.variable.data().item().unwrap();
assert!((after - before).abs() > 1e-6, "Adagrad step should change parameter");
}
#[test]
fn test_radam_steps() {
let p = make_param("w", &[1]);
let before = p.variable.data().item().unwrap();
let mut opt = RAdam::new(std::slice::from_ref(&p), 0.01);
let x = Variable::new(
Tensor::from_f32(&[2.0], &[1], crate::tensor::test_device()).unwrap(), false,
);
let loss = x.mul(&p.variable).unwrap().sum().unwrap();
loss.backward().unwrap();
opt.step().unwrap();
let after = p.variable.data().item().unwrap();
assert!((after - before).abs() > 1e-6, "RAdam step should change parameter");
}
#[test]
fn test_nadam_steps() {
let p = make_param("w", &[1]);
let before = p.variable.data().item().unwrap();
let mut opt = NAdam::new(std::slice::from_ref(&p), 0.01);
let x = Variable::new(
Tensor::from_f32(&[2.0], &[1], crate::tensor::test_device()).unwrap(), false,
);
let loss = x.mul(&p.variable).unwrap().sum().unwrap();
loss.backward().unwrap();
opt.step().unwrap();
let after = p.variable.data().item().unwrap();
assert!((after - before).abs() > 1e-6, "NAdam step should change parameter");
}
#[test]
fn test_adam_zero_lr_no_param_change() {
let p = make_param("w", &[3, 2]);
let mut opt = Adam::new(std::slice::from_ref(&p), 0.0);
let x = Variable::new(
Tensor::from_f32(&[1.0, 2.0, 3.0], &[1, 3], crate::tensor::test_device()).unwrap(),
false,
);
let before = p.variable.data().to_f32_vec().unwrap();
let y = x.matmul(&p.variable).unwrap();
y.sum().unwrap().backward().unwrap();
opt.step().unwrap();
let after = p.variable.data().to_f32_vec().unwrap();
assert_eq!(before, after, "lr=0 should leave parameters unchanged");
}
#[test]
fn test_adam_very_small_lr_no_nan() {
let p = make_param("w", &[4, 3]);
let mut opt = Adam::new(std::slice::from_ref(&p), 1e-30);
let x = Variable::new(
Tensor::from_f32(&[1.0, 2.0, 3.0, 4.0], &[1, 4], crate::tensor::test_device()).unwrap(),
false,
);
let y = x.matmul(&p.variable).unwrap();
y.sum().unwrap().backward().unwrap();
opt.step().unwrap();
let vals = p.variable.data().to_f32_vec().unwrap();
for (i, &v) in vals.iter().enumerate() {
assert!(v.is_finite(), "param[{}] is not finite: {}", i, v);
}
}
#[test]
fn test_empty_params_optimizers_no_panic() {
let empty: &[Parameter] = &[];
let mut adam = Adam::new(empty, 0.001);
adam.step().unwrap();
adam.zero_grad();
let mut sgd = SGD::new(empty, 0.01, 0.9);
sgd.step().unwrap();
sgd.zero_grad();
let mut adamw = AdamW::new(empty, 0.001, 0.01);
adamw.step().unwrap();
adamw.zero_grad();
let mut rmsprop = RMSprop::new(empty, 0.01);
rmsprop.step().unwrap();
rmsprop.zero_grad();
let mut adagrad = Adagrad::new(empty, 0.01);
adagrad.step().unwrap();
adagrad.zero_grad();
let mut radam = RAdam::new(empty, 0.01);
radam.step().unwrap();
radam.zero_grad();
let mut nadam = NAdam::new(empty, 0.01);
nadam.step().unwrap();
nadam.zero_grad();
}
#[test]
fn test_nadam_convergence_100_steps() {
use crate::nn::{Linear, Module, loss::mse_loss};
let dev = crate::tensor::test_device();
let model = Linear::on_device(4, 1, dev).unwrap();
let mut opt = NAdam::new(&model.parameters(), 0.01);
let x = Variable::new(
Tensor::from_f32(
&[1.0, 0.0, 0.0, 0.0,
0.0, 1.0, 0.0, 0.0,
0.0, 0.0, 1.0, 0.0,
0.0, 0.0, 0.0, 1.0],
&[4, 4], dev,
).unwrap(),
false,
);
let target = Variable::new(
Tensor::from_f32(&[1.0, 2.0, 3.0, 4.0], &[4, 1], dev).unwrap(),
false,
);
let first_loss;
{
let pred = model.forward(&x).unwrap();
first_loss = mse_loss(&pred, &target).unwrap().item().unwrap();
}
for _ in 0..100 {
opt.zero_grad();
let pred = model.forward(&x).unwrap();
let loss = mse_loss(&pred, &target).unwrap();
loss.backward().unwrap();
opt.step().unwrap();
}
let pred = model.forward(&x).unwrap();
let final_loss = mse_loss(&pred, &target).unwrap().item().unwrap();
assert!(final_loss < first_loss * 0.5,
"NAdam should converge: first={}, final={}", first_loss, final_loss);
}
#[test]
fn test_radam_convergence_100_steps() {
use crate::nn::{Linear, Module, loss::mse_loss};
let dev = crate::tensor::test_device();
let model = Linear::on_device(4, 1, dev).unwrap();
let mut opt = RAdam::new(&model.parameters(), 0.05);
let x = Variable::new(
Tensor::from_f32(
&[1.0, 0.0, 0.0, 0.0,
0.0, 1.0, 0.0, 0.0,
0.0, 0.0, 1.0, 0.0,
0.0, 0.0, 0.0, 1.0],
&[4, 4], dev,
).unwrap(),
false,
);
let target = Variable::new(
Tensor::from_f32(&[1.0, 2.0, 3.0, 4.0], &[4, 1], dev).unwrap(),
false,
);
let first_loss;
{
let pred = model.forward(&x).unwrap();
first_loss = mse_loss(&pred, &target).unwrap().item().unwrap();
}
for _ in 0..100 {
opt.zero_grad();
let pred = model.forward(&x).unwrap();
let loss = mse_loss(&pred, &target).unwrap();
loss.backward().unwrap();
opt.step().unwrap();
}
let pred = model.forward(&x).unwrap();
let final_loss = mse_loss(&pred, &target).unwrap().item().unwrap();
assert!(final_loss < first_loss * 0.5,
"RAdam should converge: first={}, final={}", first_loss, final_loss);
}
#[test]
fn test_adagrad_convergence_50_steps() {
use crate::nn::{Linear, Module, loss::mse_loss};
let dev = crate::tensor::test_device();
let model = Linear::on_device(4, 1, dev).unwrap();
let mut opt = Adagrad::new(&model.parameters(), 0.1);
let x = Variable::new(
Tensor::from_f32(
&[1.0, 0.0, 0.0, 0.0,
0.0, 1.0, 0.0, 0.0,
0.0, 0.0, 1.0, 0.0,
0.0, 0.0, 0.0, 1.0],
&[4, 4], dev,
).unwrap(),
false,
);
let target = Variable::new(
Tensor::from_f32(&[1.0, 2.0, 3.0, 4.0], &[4, 1], dev).unwrap(),
false,
);
let first_loss;
{
let pred = model.forward(&x).unwrap();
first_loss = mse_loss(&pred, &target).unwrap().item().unwrap();
}
for _ in 0..50 {
opt.zero_grad();
let pred = model.forward(&x).unwrap();
let loss = mse_loss(&pred, &target).unwrap();
loss.backward().unwrap();
opt.step().unwrap();
}
let pred = model.forward(&x).unwrap();
let final_loss = mse_loss(&pred, &target).unwrap().item().unwrap();
assert!(final_loss < first_loss * 0.5,
"Adagrad should converge: first={}, final={}", first_loss, final_loss);
}
#[test]
fn test_sgd_parameter_groups_different_lr() {
let dev = crate::tensor::test_device();
let p_fast = Parameter::new(
Tensor::from_f32(&[1.0, 2.0], &[1, 2], dev).unwrap(), "fast");
let p_slow = Parameter::new(
Tensor::from_f32(&[1.0, 2.0], &[1, 2], dev).unwrap(), "slow");
let mut opt = SGD::with_groups(0.0)
.group(std::slice::from_ref(&p_fast), 1.0)
.group(std::slice::from_ref(&p_slow), 0.001)
.build();
let x = Variable::new(
Tensor::from_f32(&[1.0], &[1, 1], dev).unwrap(), false,
);
let y_fast = x.matmul(&p_fast.variable).unwrap();
let y_slow = x.matmul(&p_slow.variable).unwrap();
let loss = y_fast.add(&y_slow).unwrap().sum().unwrap();
loss.backward().unwrap();
let fast_before = p_fast.variable.data().to_f32_vec().unwrap();
let slow_before = p_slow.variable.data().to_f32_vec().unwrap();
opt.step().unwrap();
let fast_after = p_fast.variable.data().to_f32_vec().unwrap();
let slow_after = p_slow.variable.data().to_f32_vec().unwrap();
let fast_delta: f64 = fast_before.iter().zip(&fast_after)
.map(|(a, b)| (a - b).abs() as f64).sum();
let slow_delta: f64 = slow_before.iter().zip(&slow_after)
.map(|(a, b)| (a - b).abs() as f64).sum();
assert!(fast_delta > slow_delta * 100.0,
"fast group (lr=1.0) should move much more than slow (lr=0.001): fast={}, slow={}",
fast_delta, slow_delta);
}
#[test]
fn test_step_after_zero_grad_on_fresh_optimizer() {
let p = make_param("w", &[3, 2]);
let mut adam = Adam::new(std::slice::from_ref(&p), 0.001);
let mut sgd = SGD::new(std::slice::from_ref(&p), 0.01, 0.9);
adam.zero_grad();
adam.step().unwrap();
sgd.zero_grad();
sgd.step().unwrap();
let vals = p.variable.data().to_f32_vec().unwrap();
for (i, &v) in vals.iter().enumerate() {
assert!(v.is_finite(), "param[{}] should be finite after step-without-backward: {}", i, v);
}
}
#[test]
fn test_double_step_without_backward_is_noop() {
let p = make_param("w", &[3, 2]);
let mut opt = Adam::new(std::slice::from_ref(&p), 0.01);
let x = Variable::new(
Tensor::from_f32(&[1.0, 2.0, 3.0], &[1, 3], crate::tensor::test_device()).unwrap(),
false,
);
let y = x.matmul(&p.variable).unwrap();
y.sum().unwrap().backward().unwrap();
opt.step().unwrap();
opt.zero_grad();
let after_first = p.variable.data().to_f32_vec().unwrap();
opt.step().unwrap();
let after_second = p.variable.data().to_f32_vec().unwrap();
assert_eq!(after_first, after_second,
"second step without backward should not change params");
}
#[test]
fn test_set_lr_all_optimizers() {
let p = make_param("w", &[2]);
let mut adam = Adam::new(std::slice::from_ref(&p), 0.001);
adam.set_lr(0.42);
assert!((adam.lr() - 0.42).abs() < 1e-12, "Adam set_lr failed");
let mut sgd = SGD::new(std::slice::from_ref(&p), 0.01, 0.0);
sgd.set_lr(0.42);
assert!((sgd.lr() - 0.42).abs() < 1e-12, "SGD set_lr failed");
let mut adamw = AdamW::new(std::slice::from_ref(&p), 0.001, 0.01);
adamw.set_lr(0.42);
assert!((adamw.lr() - 0.42).abs() < 1e-12, "AdamW set_lr failed");
let mut rmsprop = RMSprop::new(std::slice::from_ref(&p), 0.01);
rmsprop.set_lr(0.42);
assert!((rmsprop.lr() - 0.42).abs() < 1e-12, "RMSprop set_lr failed");
let mut nadam = NAdam::new(std::slice::from_ref(&p), 0.01);
nadam.set_lr(0.42);
assert!((nadam.lr() - 0.42).abs() < 1e-12, "NAdam set_lr failed");
let mut radam = RAdam::new(std::slice::from_ref(&p), 0.01);
radam.set_lr(0.42);
assert!((radam.lr() - 0.42).abs() < 1e-12, "RAdam set_lr failed");
let mut adagrad = Adagrad::new(std::slice::from_ref(&p), 0.01);
adagrad.set_lr(0.42);
assert!((adagrad.lr() - 0.42).abs() < 1e-12, "Adagrad set_lr failed");
}
#[test]
fn test_set_lr_affects_actual_update_magnitude() {
let dev = crate::tensor::test_device();
let p_lo = Parameter::new(
Tensor::from_f32(&[5.0], &[1], dev).unwrap(), "lo");
let p_hi = Parameter::new(
Tensor::from_f32(&[5.0], &[1], dev).unwrap(), "hi");
let mut opt_lo = SGD::new(std::slice::from_ref(&p_lo), 0.001, 0.0);
let mut opt_hi = SGD::new(std::slice::from_ref(&p_hi), 0.001, 0.0);
opt_hi.set_lr(1.0);
let x = Variable::new(
Tensor::from_f32(&[1.0], &[1], dev).unwrap(), false,
);
let loss_lo = x.mul(&p_lo.variable).unwrap().sum().unwrap();
loss_lo.backward().unwrap();
let loss_hi = x.mul(&p_hi.variable).unwrap().sum().unwrap();
loss_hi.backward().unwrap();
opt_lo.step().unwrap();
opt_hi.step().unwrap();
let val_lo = p_lo.variable.data().to_f32_vec().unwrap()[0];
let val_hi = p_hi.variable.data().to_f32_vec().unwrap()[0];
let delta_lo = (5.0 - val_lo).abs();
let delta_hi = (5.0 - val_hi).abs();
assert!(delta_hi > delta_lo * 100.0,
"set_lr(1.0) should produce much larger update than 0.001: hi={}, lo={}",
delta_hi, delta_lo);
}
}