use zyx::Tensor;
use zyx_derive::Module;
/// RMSprop optimizer state and hyperparameters.
///
/// Hyperparameters are public fields; the per-parameter running statistics
/// are private and are grown by `update`, indexed by the position of each
/// parameter in the iterator passed to `update`.
#[derive(Module)]
#[cfg_attr(feature = "py", pyo3::pyclass)]
pub struct RMSprop {
/// Step size multiplied into every update (and into the weight-decay factor).
pub learning_rate: f32,
/// Smoothing constant of the running averages: `avg = alpha * avg + (1 - alpha) * new`.
pub alpha: f32,
/// Constant added to the square-root denominator before dividing the gradient by it.
pub eps: f32,
/// Momentum factor; the momentum buffer is only used when this is greater than zero.
pub momentum: f32,
/// When true, the squared running mean of the gradient is subtracted from the
/// running mean of the squared gradient before taking the square root.
pub centered: bool,
/// When greater than zero, each parameter is scaled by
/// `1 - learning_rate * weight_decay` after its step.
pub weight_decay: f32,
/// Step counter; incremented once per call to `update`.
pub t: usize,
// Running average of squared gradients, one tensor per parameter slot.
buffer: Vec<Tensor>,
// Accumulated momentum-smoothed updates, one tensor per parameter slot.
momentum_buf: Vec<Tensor>,
// Running average of gradients; only populated/used in centered mode.
grad_avg: Vec<Tensor>,
}
impl Default for RMSprop {
fn default() -> Self {
Self {
learning_rate: 0.01,
alpha: 0.99,
eps: 1e-8,
momentum: 0.0,
centered: false,
weight_decay: 0.0,
t: 0,
buffer: Vec::new(),
momentum_buf: Vec::new(),
grad_avg: Vec::new(),
}
}
}
impl RMSprop {
    /// Applies one RMSprop step to `parameters` using the matching `gradients`.
    ///
    /// Parameters and gradients are paired positionally; iteration stops at the
    /// shorter of the two. A `None` gradient leaves its parameter untouched but
    /// still reserves that parameter's state slot so later indices stay aligned.
    ///
    /// For every parameter `p` with gradient `g`:
    /// - `buffer = alpha * buffer + (1 - alpha) * g * g`
    /// - if `centered`: `grad_avg = alpha * grad_avg + (1 - alpha) * g` and
    ///   `denom = sqrt(relu(buffer - grad_avg * grad_avg)) + eps`,
    ///   otherwise `denom = sqrt(buffer) + eps`
    /// - `step = learning_rate * g / denom`
    /// - if `momentum > 0`: `momentum_buf = momentum * momentum_buf + step` and
    ///   `p -= momentum_buf`, otherwise `p -= step`
    /// - if `weight_decay > 0`: `p *= 1 - learning_rate * weight_decay`
    ///
    /// All running statistics start at zero, so the first gradient contributes
    /// exactly `(1 - alpha)` of its value to the averages. The step counter `t`
    /// is incremented once per call.
    pub fn update<'a>(
        &mut self,
        parameters: impl IntoIterator<Item = &'a mut Tensor>,
        gradients: impl IntoIterator<Item = Option<Tensor>>,
    ) {
        for (i, (param, grad)) in parameters.into_iter().zip(gradients).enumerate() {
            // Lazily create zeroed state for a parameter slot seen for the first
            // time. Seeding with zeros (rather than with the first gradient)
            // keeps the moving-average update below from counting that gradient
            // twice, and makes the result independent of whether the slot's
            // first gradient was `None`.
            if self.buffer.len() <= i {
                self.buffer.push(Tensor::zeros_like(&*param));
                self.momentum_buf.push(Tensor::zeros_like(&*param));
            }
            // `centered` is a public field and may be switched on after state
            // for this slot already exists, so grow `grad_avg` independently.
            if self.centered && self.grad_avg.len() <= i {
                self.grad_avg.push(Tensor::zeros_like(&*param));
            }

            // No gradient: nothing to apply for this parameter.
            let Some(grad) = grad else {
                continue;
            };

            // Running average of squared gradients.
            self.buffer[i] = &self.buffer[i] * self.alpha + &grad * &grad * (1.0 - self.alpha);

            let denom = if self.centered {
                // Running average of gradients; subtracting its square turns the
                // second moment into a variance estimate. `relu` clamps small
                // negative values so `sqrt` never sees a negative input.
                self.grad_avg[i] = &self.grad_avg[i] * self.alpha + &grad * (1.0 - self.alpha);
                let avg = &self.grad_avg[i];
                (&self.buffer[i] - avg * avg).relu().sqrt() + self.eps
            } else {
                self.buffer[i].sqrt() + self.eps
            };

            let update = &grad / denom * self.learning_rate;
            if self.momentum > 0.0 {
                self.momentum_buf[i] = &self.momentum_buf[i] * self.momentum + &update;
                *param = &*param - &self.momentum_buf[i];
            } else {
                *param = &*param - update;
            }

            // Weight decay is applied as a multiplicative shrink after the step.
            if self.weight_decay > 0.0 {
                *param = &*param * (1.0 - self.learning_rate * self.weight_decay);
            }
        }
        self.t += 1;
    }
}