use crate::core::{error::BellandeError, tensor::Tensor};
use crate::optim::{Optimizer, OptimizerState, ParameterGroup};
use std::collections::HashMap;
/// RMSprop optimizer with optional momentum and optional centering.
///
/// Per-parameter state is kept in maps keyed by the parameter's position in
/// `params`; each value has one entry per element of that parameter's data.
pub struct RMSprop {
    /// Parameters updated by `step`.
    params: Vec<Tensor>,
    /// Step size applied to every update.
    lr: f32,
    /// Decay factor of the running averages.
    alpha: f32,
    /// Small constant added inside the square root of the denominator.
    eps: f32,
    /// Coefficient of the parameter value added to the gradient (skipped when zero).
    weight_decay: f32,
    /// Momentum factor; the momentum buffer is only used when this is positive.
    momentum: f32,
    /// Whether the running average of the gradient is tracked as well.
    centered: bool,
    /// Running average of the squared gradient, per parameter index.
    v: HashMap<usize, Vec<f32>>,
    /// Running average of the gradient, per parameter index (populated only when `centered`).
    g: HashMap<usize, Vec<f32>>,
    /// Momentum buffer, per parameter index (populated only when `momentum > 0.0`).
    buf: HashMap<usize, Vec<f32>>,
    /// Parameter groups exposed through the `Optimizer` trait.
    param_groups: Vec<ParameterGroup>,
    /// Shared optimizer bookkeeping; its step counter is advanced by `step`.
    state: OptimizerState,
}
impl RMSprop {
    /// Builds an RMSprop optimizer over `params`.
    ///
    /// * `lr` - step size.
    /// * `alpha` - decay factor of the running averages.
    /// * `eps` - constant added to the denominator for numerical stability.
    /// * `weight_decay` - coefficient of the parameter value added to the gradient.
    /// * `momentum` - momentum factor; a momentum buffer is allocated only when it is positive.
    /// * `centered` - when `true`, a running average of the gradient is allocated as well.
    ///
    /// All state buffers start zero-filled, and a single default parameter
    /// group carrying `lr`, `weight_decay`, `momentum` and `eps` is registered.
    pub fn new(
        params: Vec<Tensor>,
        lr: f32,
        alpha: f32,
        eps: f32,
        weight_decay: f32,
        momentum: f32,
        centered: bool,
    ) -> Self {
        // One zero-filled buffer per parameter, keyed by its index, or an
        // empty map when that piece of state is not needed.
        let zeroed_state = |enabled: bool| -> HashMap<usize, Vec<f32>> {
            if !enabled {
                return HashMap::new();
            }
            params
                .iter()
                .enumerate()
                .map(|(idx, param)| (idx, vec![0.0; param.data.len()]))
                .collect()
        };

        let v = zeroed_state(true);
        let g = zeroed_state(centered);
        let buf = zeroed_state(momentum > 0.0);

        let param_groups = vec![ParameterGroup::new(params.clone())
            .with_lr(lr)
            .with_weight_decay(weight_decay)
            .with_momentum(momentum)
            .with_eps(eps)];

        RMSprop {
            params,
            lr,
            alpha,
            eps,
            weight_decay,
            momentum,
            centered,
            v,
            g,
            buf,
            param_groups,
            state: OptimizerState::new(),
        }
    }
}
impl Optimizer for RMSprop {
    /// Applies one RMSprop update to every parameter that currently has a gradient.
    ///
    /// For each element, with `grad` already including weight decay:
    ///
    /// ```text
    /// v     = alpha * v + (1 - alpha) * grad^2
    /// g     = alpha * g + (1 - alpha) * grad          (centered only)
    /// denom = sqrt(v - g^2 + eps)                      (centered)
    ///       = sqrt(v + eps)                            (otherwise)
    /// buf   = momentum * buf + grad / denom            (momentum > 0 only)
    /// param -= lr * (buf or grad / denom)
    /// ```
    ///
    /// The optimizer-level hyperparameters are used for every parameter; the
    /// per-group values stored in `param_groups` are not consulted here.
    /// Parameters without a gradient are skipped. State for a parameter index
    /// that has none yet is created zero-filled on first use.
    ///
    /// # Errors
    ///
    /// This implementation always returns `Ok(())`.
    fn step(&mut self) -> Result<(), BellandeError> {
        self.state.increment_step();

        // Copy the scalar hyperparameters so the loop below only borrows the
        // fields it actually mutates.
        let lr = self.lr;
        let alpha = self.alpha;
        let eps = self.eps;
        let weight_decay = self.weight_decay;
        let momentum = self.momentum;
        let centered = self.centered;

        for (idx, param) in self.params.iter_mut().enumerate() {
            let grad = match &param.grad {
                Some(grad) => grad,
                None => continue,
            };
            let len = param.data.len();

            // Lazily create missing state instead of panicking, so tensors
            // appended through `parameters_mut` are handled too.
            let v = self.v.entry(idx).or_insert_with(|| vec![0.0; len]);
            let mut g = if centered {
                Some(self.g.entry(idx).or_insert_with(|| vec![0.0; len]))
            } else {
                None
            };
            let mut buf = if momentum > 0.0 {
                Some(self.buf.entry(idx).or_insert_with(|| vec![0.0; len]))
            } else {
                None
            };

            for i in 0..len {
                let mut update = grad[i];
                if weight_decay != 0.0 {
                    update += weight_decay * param.data[i];
                }

                v[i] = alpha * v[i] + (1.0 - alpha) * update * update;

                let denom = if let Some(g_avg) = &mut g {
                    g_avg[i] = alpha * g_avg[i] + (1.0 - alpha) * update;
                    // Centered variant: normalise by the variance estimate
                    // E[grad^2] - E[grad]^2. It is non-negative in exact
                    // arithmetic; the clamp guards against rounding error
                    // turning the square root into NaN.
                    ((v[i] - g_avg[i] * g_avg[i]).max(0.0) + eps).sqrt()
                } else {
                    (v[i] + eps).sqrt()
                };
                update /= denom;

                let delta = if let Some(buf_val) = &mut buf {
                    buf_val[i] = momentum * buf_val[i] + update;
                    buf_val[i]
                } else {
                    update
                };
                param.data[i] -= lr * delta;
            }
        }
        Ok(())
    }

    /// Sets every element of every existing gradient to zero.
    ///
    /// Parameters whose gradient is `None` are left untouched.
    fn zero_grad(&mut self) {
        for param in &mut self.params {
            if let Some(grad) = &mut param.grad {
                grad.iter_mut().for_each(|g| *g = 0.0);
            }
        }
    }

    /// Returns the optimizer-level learning rate.
    fn get_lr(&self) -> f32 {
        self.lr
    }

    /// Sets the learning rate on the optimizer and on every parameter group.
    fn set_lr(&mut self, lr: f32) {
        self.lr = lr;
        for group in &mut self.param_groups {
            group.lr = lr;
        }
    }

    /// Returns the parameters managed by this optimizer.
    fn parameters(&self) -> &Vec<Tensor> {
        &self.params
    }

    /// Returns mutable access to the parameters managed by this optimizer.
    fn parameters_mut(&mut self) -> &mut Vec<Tensor> {
        &mut self.params
    }

    /// Returns the optimizer's name, `"RMSprop"`.
    fn name(&self) -> &str {
        "RMSprop"
    }

    /// Returns the registered parameter groups.
    fn get_param_groups(&self) -> &[ParameterGroup] {
        &self.param_groups
    }

    /// Returns mutable access to the registered parameter groups.
    fn get_param_groups_mut(&mut self) -> &mut [ParameterGroup] {
        &mut self.param_groups
    }

    /// Registers `group` and starts optimizing its parameters.
    ///
    /// The group's tensors are appended to the optimizer's parameter list, so
    /// their state is keyed by indices continuing after the existing
    /// parameters. Zero-filled state is allocated following the same rules as
    /// `new`: the squared-gradient average always, the gradient average only
    /// when centered, the momentum buffer only when momentum is positive.
    fn add_param_group(&mut self, group: ParameterGroup) {
        let start_idx = self.params.len();
        for (offset, param) in group.params.iter().enumerate() {
            let idx = start_idx + offset;
            let len = param.data.len();
            self.v.insert(idx, vec![0.0; len]);
            if self.centered {
                self.g.insert(idx, vec![0.0; len]);
            }
            if self.momentum > 0.0 {
                self.buf.insert(idx, vec![0.0; len]);
            }
        }
        self.params.extend(group.params.iter().cloned());
        self.param_groups.push(group);
    }

    /// Returns the optimizer's bookkeeping state.
    fn state(&self) -> &OptimizerState {
        &self.state
    }

    /// Returns mutable access to the optimizer's bookkeeping state.
    fn state_mut(&mut self) -> &mut OptimizerState {
        &mut self.state
    }
}
// SAFETY: NOTE(review): these impls assert that an `RMSprop` may be moved to
// and shared between threads. Whether that holds depends on `Tensor`,
// `ParameterGroup` and `OptimizerState`, which are defined elsewhere — TODO
// confirm none of them contains non-thread-safe state. If all of them are
// already `Send + Sync`, the compiler derives both traits automatically and
// these manual impls can be removed.
unsafe impl Send for RMSprop {}
unsafe impl Sync for RMSprop {}