use std::collections::HashMap;
use std::collections::hash_map::Entry;
use yscv_autograd::{Graph, NodeId};
use yscv_tensor::Tensor;
use super::validate::{validate_epsilon, validate_lr, validate_momentum, validate_rmsprop_alpha};
use super::{LearningRate, OptimError};
/// Per-parameter buffers kept by [`RmsProp`] between optimisation steps.
///
/// All three tensors are created with the same shape (see `RmsPropState::new`).
#[derive(Debug, Clone)]
struct RmsPropState {
/// Exponentially decayed average of the squared (weight-decayed) gradient.
square_avg: Tensor,
/// Exponentially decayed average of the gradient itself; only updated when
/// the optimizer runs in centered mode.
grad_avg: Tensor,
/// Accumulated normalised updates; only updated when momentum is non-zero.
momentum_buffer: Tensor,
}
impl RmsPropState {
    /// Allocates a fresh state whose three buffers are zero tensors of `shape`.
    ///
    /// # Errors
    ///
    /// Propagates any failure reported by `Tensor::zeros`.
    fn new(shape: &[usize]) -> Result<Self, OptimError> {
        // Buffers are allocated in declaration order so the first failing
        // allocation is the one that gets reported.
        let square_avg = Tensor::zeros(shape.to_vec())?;
        let grad_avg = Tensor::zeros(shape.to_vec())?;
        let momentum_buffer = Tensor::zeros(shape.to_vec())?;
        Ok(Self {
            square_avg,
            grad_avg,
            momentum_buffer,
        })
    }

    /// Replaces every buffer with zeros of the given `shape`.
    ///
    /// The existing state is left untouched if allocation fails.
    ///
    /// # Errors
    ///
    /// Propagates any failure reported while building the replacement state.
    fn reset(&mut self, shape: &[usize]) -> Result<(), OptimError> {
        Self::new(shape).map(|fresh| *self = fresh)
    }
}
/// RMSProp optimizer with optional weight decay, momentum and centering.
///
/// Hyper-parameters are set through `new` and the `with_*` builder methods;
/// per-parameter running statistics are stored in `state`.
#[derive(Debug, Clone)]
pub struct RmsProp {
/// Step size multiplied into every weight update.
lr: f32,
/// Decay factor of the running averages (`new` sets it to 0.99).
alpha: f32,
/// Constant added to the square-root denominator (`new` sets it to 1e-8).
epsilon: f32,
/// Coefficient of the `weight_decay * weight` term added to the gradient.
weight_decay: f32,
/// Momentum factor; the momentum buffer is only used when this is non-zero.
momentum: f32,
/// When true, the squared running mean of the gradient is subtracted from
/// the running average of squared gradients before taking the square root.
centered: bool,
/// Running statistics keyed by the parameter id passed to `step`.
state: HashMap<u64, RmsPropState>,
}
impl RmsProp {
    /// Creates an optimizer with the given learning rate and default
    /// hyper-parameters: `alpha = 0.99`, `epsilon = 1e-8`, no weight decay,
    /// no momentum, not centered, and no stored state.
    ///
    /// # Errors
    ///
    /// Returns the error produced by `validate_lr` when `lr` is rejected.
    pub fn new(lr: f32) -> Result<Self, OptimError> {
        validate_lr(lr)?;
        let optimizer = Self {
            lr,
            alpha: 0.99,
            epsilon: 1e-8,
            weight_decay: 0.0,
            momentum: 0.0,
            centered: false,
            state: HashMap::new(),
        };
        Ok(optimizer)
    }

    /// Returns the optimizer with its running-average decay set to `alpha`.
    ///
    /// # Errors
    ///
    /// Returns the error produced by `validate_rmsprop_alpha` when `alpha`
    /// is rejected.
    pub fn with_alpha(self, alpha: f32) -> Result<Self, OptimError> {
        validate_rmsprop_alpha(alpha)?;
        Ok(Self { alpha, ..self })
    }

    /// Returns the optimizer with its denominator constant set to `epsilon`.
    ///
    /// # Errors
    ///
    /// Returns the error produced by `validate_epsilon` when `epsilon` is
    /// rejected.
    pub fn with_epsilon(self, epsilon: f32) -> Result<Self, OptimError> {
        validate_epsilon(epsilon)?;
        Ok(Self { epsilon, ..self })
    }

    /// Returns the optimizer with its weight-decay coefficient replaced.
    ///
    /// # Errors
    ///
    /// Returns `OptimError::InvalidWeightDecay` when `weight_decay` is
    /// negative, infinite or NaN.
    pub fn with_weight_decay(self, weight_decay: f32) -> Result<Self, OptimError> {
        // NaN fails `is_finite`, so it lands in the error branch as well.
        if weight_decay.is_finite() && weight_decay >= 0.0 {
            Ok(Self {
                weight_decay,
                ..self
            })
        } else {
            Err(OptimError::InvalidWeightDecay { weight_decay })
        }
    }

    /// Returns the optimizer with its momentum factor set to `momentum`.
    ///
    /// # Errors
    ///
    /// Returns the error produced by `validate_momentum` when `momentum` is
    /// rejected.
    pub fn with_momentum(self, momentum: f32) -> Result<Self, OptimError> {
        validate_momentum(momentum)?;
        Ok(Self { momentum, ..self })
    }

    /// Returns the optimizer with centered mode switched on or off.
    pub fn with_centered(self, centered: bool) -> Self {
        Self { centered, ..self }
    }

    /// Discards the running statistics of every parameter seen so far.
    pub fn clear_state(&mut self) {
        self.state.clear();
    }

    /// Current learning rate.
    pub fn learning_rate(&self) -> f32 {
        self.lr
    }

    /// Replaces the learning rate.
    ///
    /// # Errors
    ///
    /// Returns the error produced by `validate_lr` when `lr` is rejected; the
    /// stored learning rate is left unchanged in that case.
    pub fn set_learning_rate(&mut self, lr: f32) -> Result<(), OptimError> {
        validate_lr(lr)?;
        self.lr = lr;
        Ok(())
    }

    /// Applies one RMSProp update to `weights` in place using `grad`.
    ///
    /// Running statistics are looked up by `parameter_id`; they are created
    /// on first use and re-zeroed if the stored shape no longer matches
    /// `weights`.
    ///
    /// Per element, with `g = grad + weight_decay * weight`:
    ///
    /// * `square_avg = alpha * square_avg + (1 - alpha) * g * g`
    /// * centered mode additionally tracks `grad_avg` the same way and uses
    ///   `max(square_avg - grad_avg^2, 0)` in place of `square_avg`
    /// * `normalized = g / (sqrt(avg) + epsilon)`
    /// * with non-zero momentum the update is
    ///   `buffer = momentum * buffer + normalized`, otherwise `normalized`
    /// * `weight -= lr * update`
    ///
    /// # Errors
    ///
    /// Returns `OptimError::ShapeMismatch` when `weights` and `grad` have
    /// different shapes, and propagates failures from allocating state.
    pub fn step(
        &mut self,
        parameter_id: u64,
        weights: &mut Tensor,
        grad: &Tensor,
    ) -> Result<(), OptimError> {
        let (weight_shape, grad_shape) = (weights.shape(), grad.shape());
        if weight_shape != grad_shape {
            return Err(OptimError::ShapeMismatch {
                weights: weight_shape.to_vec(),
                grad: grad_shape.to_vec(),
            });
        }

        // Hyper-parameters are plain floats/bools; copy them out once.
        let lr = self.lr;
        let alpha = self.alpha;
        let one_minus_alpha = 1.0 - alpha;
        let epsilon = self.epsilon;
        let weight_decay = self.weight_decay;
        let momentum = self.momentum;
        let centered = self.centered;

        // `or_insert_with` cannot be used here because creating the state is
        // fallible, hence the explicit match on the entry.
        let state = match self.state.entry(parameter_id) {
            Entry::Vacant(slot) => slot.insert(RmsPropState::new(weight_shape)?),
            Entry::Occupied(slot) => slot.into_mut(),
        };
        // A parameter id that is reused with a different shape starts over.
        if state.square_avg.shape() != weight_shape {
            state.reset(weight_shape)?;
        }

        let grad_values = grad.data();
        let weights_data = weights.data_mut();
        let square_avg = state.square_avg.data_mut();
        let grad_avg = state.grad_avg.data_mut();
        let momentum_buffer = state.momentum_buffer.data_mut();

        for (i, weight) in weights_data.iter_mut().enumerate() {
            // Weight decay is folded into the gradient before any statistics.
            let g = grad_values[i] + weight_decay * *weight;

            let mean_square = alpha * square_avg[i] + one_minus_alpha * g * g;
            square_avg[i] = mean_square;

            let variance = if centered {
                let mean = alpha * grad_avg[i] + one_minus_alpha * g;
                grad_avg[i] = mean;
                // Clamp at zero so the square root below never sees a
                // negative value.
                (mean_square - mean * mean).max(0.0)
            } else {
                mean_square
            };

            let normalized = g / (variance.sqrt() + epsilon);

            let delta = if momentum != 0.0 {
                let velocity = momentum * momentum_buffer[i] + normalized;
                momentum_buffer[i] = velocity;
                velocity
            } else {
                normalized
            };

            *weight -= lr * delta;
        }
        Ok(())
    }

    /// Runs [`RmsProp::step`] on the value stored at `node` of `graph`, using
    /// the node's index (`node.0`) as the parameter id.
    ///
    /// Nodes that do not require a gradient are skipped and `Ok(())` is
    /// returned.
    ///
    /// # Errors
    ///
    /// Returns `OptimError::MissingGradient` when the node requires a
    /// gradient but none is stored, and propagates errors from the graph
    /// accessors and from `step`.
    pub fn step_graph_node(&mut self, graph: &mut Graph, node: NodeId) -> Result<(), OptimError> {
        let trainable = graph.requires_grad(node)?;
        if !trainable {
            return Ok(());
        }
        // The gradient is cloned so that the graph can be borrowed mutably
        // for the weight tensor right after.
        let grad = graph
            .grad(node)?
            .map(|stored| stored.clone())
            .ok_or_else(|| OptimError::MissingGradient { node: node.0 })?;
        let weights = graph.value_mut(node)?;
        self.step(node.0 as u64, weights, &grad)
    }
}
/// Exposes the optimizer's learning rate through the shared `LearningRate`
/// trait by forwarding to the inherent methods of the same name.
impl LearningRate for RmsProp {
    /// Current learning rate.
    fn learning_rate(&self) -> f32 {
        // A type-qualified path resolves to the inherent method, not to this
        // trait method, so this does not recurse.
        Self::learning_rate(self)
    }

    /// Replaces the learning rate, returning whatever error the inherent
    /// setter reports for a rejected value.
    fn set_learning_rate(&mut self, lr: f32) -> Result<(), OptimError> {
        Self::set_learning_rate(self, lr)
    }
}