use axonml_nn::Parameter;
use axonml_tensor::Tensor;
use crate::optimizer::Optimizer;
use axonml_core;
/// RMSprop optimizer (Tieleman & Hinton): scales each gradient by a running
/// average of its recent squared magnitude. Optionally supports L2 weight
/// decay, classical momentum, and the "centered" variant.
pub struct RMSprop {
/// Parameters this optimizer updates.
params: Vec<Parameter>,
/// Learning rate.
lr: f32,
/// Smoothing constant for the running squared-gradient average (default 0.99).
alpha: f32,
/// Term added to the denominator for numerical stability (default 1e-8).
eps: f32,
/// L2 penalty coefficient; 0.0 disables weight decay.
weight_decay: f32,
/// Momentum factor; 0.0 disables the momentum buffer.
momentum: f32,
/// When true, normalizes by the variance estimate (square_avg - grad_avg^2).
centered: bool,
/// Per-parameter buffers, allocated lazily on the first `step()`.
state: Vec<RMSpropState>,
}
/// Per-parameter running state for RMSprop.
#[derive(Debug, Clone)]
struct RMSpropState {
/// Exponential moving average of squared gradients.
square_avg: Tensor<f32>,
/// Momentum buffer; `Some` only when momentum != 0.0.
momentum_buffer: Option<Tensor<f32>>,
/// Moving average of raw gradients; `Some` only for the centered variant.
grad_avg: Option<Tensor<f32>>,
}
impl RMSpropState {
    /// Allocates zero-initialized state buffers for one parameter tensor.
    ///
    /// `momentum` and `centered` decide whether the corresponding optional
    /// buffers are created. Buffers are moved to `device` when it is a GPU
    /// device, so state lives alongside the parameter data.
    fn new(shape: &[usize], device: axonml_core::Device, momentum: bool, centered: bool) -> Self {
        // One shared allocator instead of three copy-pasted blocks: a zero
        // tensor of the parameter's shape, transferred to the GPU if needed.
        // NOTE(review): assumes `to_device` cannot fail for a fresh zero
        // tensor; panics otherwise (same behavior as the original code).
        let zeros = || {
            let t = Tensor::zeros(shape);
            if device.is_gpu() {
                t.to_device(device).unwrap()
            } else {
                t
            }
        };
        Self {
            square_avg: zeros(),
            momentum_buffer: if momentum { Some(zeros()) } else { None },
            grad_avg: if centered { Some(zeros()) } else { None },
        }
    }
}
impl RMSprop {
    /// Creates an RMSprop optimizer with the conventional defaults:
    /// alpha = 0.99, eps = 1e-8, no weight decay, no momentum, not centered.
    #[must_use]
    pub fn new(params: Vec<Parameter>, lr: f32) -> Self {
        // Delegate so the defaults live in exactly one place.
        Self::with_options(params, lr, 0.99, 1e-8, 0.0, 0.0, false)
    }

    /// Like [`RMSprop::new`] but with an explicit smoothing constant `alpha`.
    #[must_use]
    pub fn with_alpha(params: Vec<Parameter>, lr: f32, alpha: f32) -> Self {
        Self::with_options(params, lr, alpha, 1e-8, 0.0, 0.0, false)
    }

    /// Fully-parameterized constructor; all hyperparameters explicit.
    #[must_use]
    pub fn with_options(
        params: Vec<Parameter>,
        lr: f32,
        alpha: f32,
        eps: f32,
        weight_decay: f32,
        momentum: f32,
        centered: bool,
    ) -> Self {
        Self {
            params,
            lr,
            alpha,
            eps,
            weight_decay,
            momentum,
            centered,
            // State is allocated lazily on the first `step()`, once the
            // parameter shapes and devices are actually needed.
            state: Vec::new(),
        }
    }

    /// Sets the smoothing constant for the squared-gradient average (builder style).
    #[must_use]
    pub fn alpha(mut self, alpha: f32) -> Self {
        self.alpha = alpha;
        self
    }

    /// Sets the numerical-stability epsilon (builder style).
    #[must_use]
    pub fn eps(mut self, eps: f32) -> Self {
        self.eps = eps;
        self
    }

    /// Sets the L2 weight-decay coefficient (builder style).
    #[must_use]
    pub fn weight_decay(mut self, weight_decay: f32) -> Self {
        self.weight_decay = weight_decay;
        self
    }

    /// Sets the momentum factor (builder style).
    #[must_use]
    pub fn momentum(mut self, momentum: f32) -> Self {
        self.momentum = momentum;
        self
    }

    /// Enables/disables the centered variant (builder style).
    #[must_use]
    pub fn centered(mut self, centered: bool) -> Self {
        self.centered = centered;
        self
    }

    /// Lazily allocates per-parameter state before the first update.
    fn ensure_state_initialized(&mut self) {
        // Once built, `state.len()` always equals `params.len()`, so this is
        // a one-time initialization. (Checking length instead of `is_empty()`
        // also avoids rebuilding every step when `params` is empty.)
        if self.state.len() != self.params.len() {
            self.state = self
                .params
                .iter()
                .map(|p| {
                    let data = p.data();
                    RMSpropState::new(
                        data.shape(),
                        data.device(),
                        self.momentum != 0.0,
                        self.centered,
                    )
                })
                .collect();
        }
    }
}
impl Optimizer for RMSprop {
    /// Performs one RMSprop update on every parameter that has a gradient.
    ///
    /// Per parameter:
    ///   d         = grad + weight_decay * param          (if weight_decay != 0)
    ///   sq_avg    = alpha * sq_avg + (1 - alpha) * d^2
    ///   denom     = sqrt(sq_avg) + eps                    (plain variant)
    ///   denom     = sqrt(sq_avg - grad_avg^2) + eps       (centered variant)
    ///   buf       = momentum * buf + d / denom            (if momentum != 0)
    ///   param    -= lr * update
    fn step(&mut self) {
        self.ensure_state_initialized();
        for (i, param) in self.params.iter().enumerate() {
            if !param.requires_grad() {
                continue;
            }
            // Parameters with no accumulated gradient are skipped entirely.
            let grad = match param.grad() {
                Some(g) => g,
                None => continue,
            };
            let param_data = param.data();
            let state = &mut self.state[i];
            // Fold L2 regularization into the effective gradient.
            // BUG FIX: this line previously read `¶m_data` (mojibake for
            // `&param_data`) and did not compile.
            let d = if self.weight_decay == 0.0 {
                grad.clone()
            } else {
                grad.add(&param_data.mul_scalar(self.weight_decay)).unwrap()
            };
            // Update the running average of squared gradients.
            let d_sq = d.mul(&d).unwrap();
            state.square_avg = state
                .square_avg
                .mul_scalar(self.alpha)
                .add(&d_sq.mul_scalar(1.0 - self.alpha))
                .unwrap();
            let denom = if self.centered {
                // Centered variant: also track the mean gradient and
                // normalize by the variance estimate sq_avg - grad_avg^2.
                let grad_avg = state.grad_avg.as_mut().unwrap();
                *grad_avg = grad_avg
                    .mul_scalar(self.alpha)
                    .add(&d.mul_scalar(1.0 - self.alpha))
                    .unwrap();
                let ga_sq = grad_avg.mul(grad_avg).unwrap();
                state
                    .square_avg
                    .sub(&ga_sq)
                    .unwrap()
                    .sqrt()
                    .add_scalar(self.eps)
            } else {
                state.square_avg.sqrt().add_scalar(self.eps)
            };
            let update = if self.momentum == 0.0 {
                d.div(&denom).unwrap()
            } else {
                // Classical momentum on the normalized gradient.
                let normalized = d.div(&denom).unwrap();
                let buf = state.momentum_buffer.as_mut().unwrap();
                *buf = buf.mul_scalar(self.momentum).add(&normalized).unwrap();
                buf.clone()
            };
            let new_param = param_data.sub(&update.mul_scalar(self.lr)).unwrap();
            param.update_data(new_param);
        }
    }

    /// Clears the gradient of every registered parameter.
    fn zero_grad(&mut self) {
        for param in &self.params {
            param.zero_grad();
        }
    }

    /// Returns the current learning rate.
    fn get_lr(&self) -> f32 {
        self.lr
    }

    /// Sets the learning rate (e.g. from an LR scheduler).
    fn set_lr(&mut self, lr: f32) {
        self.lr = lr;
    }

    /// Returns the parameters managed by this optimizer.
    fn parameters(&self) -> &[Parameter] {
        &self.params
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use axonml_autograd::Variable;

    /// Builds a trainable 3-element parameter initialized to [1.0, 2.0, 3.0].
    fn make_param() -> Parameter {
        let data = Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();
        Parameter::from_variable(Variable::new(data, true))
    }

    #[test]
    fn test_rmsprop_creation() {
        let opt = RMSprop::new(vec![make_param()], 0.01);
        assert!((opt.get_lr() - 0.01).abs() < 1e-6);
        assert!((opt.alpha - 0.99).abs() < 1e-6);
    }

    #[test]
    fn test_rmsprop_step() {
        let param = make_param();
        param
            .variable()
            .set_grad(Tensor::from_vec(vec![0.1, 0.2, 0.3], &[3]).unwrap());
        let mut opt = RMSprop::new(vec![param.clone()], 0.01);
        opt.step();
        // After one step the first element must have moved off its initial value.
        let updated = param.data().to_vec();
        assert!((updated[0] - 1.0).abs() > 1e-6);
    }

    #[test]
    fn test_rmsprop_with_momentum() {
        let opt = RMSprop::new(vec![make_param()], 0.01).momentum(0.9);
        assert!((opt.momentum - 0.9).abs() < 1e-6);
    }

    #[test]
    fn test_rmsprop_centered() {
        let opt = RMSprop::new(vec![make_param()], 0.01).centered(true);
        assert!(opt.centered);
    }

    #[test]
    fn test_rmsprop_builder_pattern() {
        let opt = RMSprop::new(vec![make_param()], 0.01)
            .alpha(0.95)
            .eps(1e-6)
            .weight_decay(0.0001)
            .momentum(0.9)
            .centered(true);
        assert!((opt.alpha - 0.95).abs() < 1e-6);
        assert!((opt.eps - 1e-6).abs() < 1e-9);
        assert!((opt.weight_decay - 0.0001).abs() < 1e-6);
        assert!((opt.momentum - 0.9).abs() < 1e-6);
        assert!(opt.centered);
    }
}