use std::collections::hash_map::Entry;
use std::collections::HashMap;

use yscv_autograd::{Graph, NodeId};
use yscv_tensor::Tensor;

use super::validate::{validate_epsilon, validate_lr};
use super::{LearningRate, OptimError};
/// Per-parameter Adagrad state: the running sum of squared gradients.
#[derive(Debug, Clone)]
struct AdagradState {
    sum_sq: Tensor,
}

impl AdagradState {
    fn new(shape: &[usize]) -> Result<Self, OptimError> {
        Ok(Self {
            sum_sq: Tensor::zeros(shape.to_vec())?,
        })
    }

    /// Re-initializes the accumulator, e.g. after a parameter's shape changes.
    fn reset(&mut self, shape: &[usize]) -> Result<(), OptimError> {
        *self = Self::new(shape)?;
        Ok(())
    }
}
/// Adagrad optimizer: adapts the step size per element by dividing the
/// gradient by the square root of the accumulated squared gradients.
#[derive(Debug, Clone)]
pub struct Adagrad {
    lr: f32,
    epsilon: f32,
    weight_decay: f32,
    /// Accumulator state keyed by parameter id.
    state: HashMap<u64, AdagradState>,
}
impl Adagrad {
    /// Creates an optimizer with the given learning rate, an epsilon of
    /// `1e-10`, and no weight decay.
    pub fn new(lr: f32) -> Result<Self, OptimError> {
        validate_lr(lr)?;
        Ok(Self {
            lr,
            epsilon: 1e-10,
            weight_decay: 0.0,
            state: HashMap::new(),
        })
    }

    /// Sets the epsilon term added to the denominator for numerical stability.
    pub fn with_epsilon(mut self, epsilon: f32) -> Result<Self, OptimError> {
        validate_epsilon(epsilon)?;
        self.epsilon = epsilon;
        Ok(self)
    }

    /// Sets the L2 weight-decay coefficient; it must be finite and non-negative.
    pub fn with_weight_decay(mut self, weight_decay: f32) -> Result<Self, OptimError> {
        if !weight_decay.is_finite() || weight_decay < 0.0 {
            return Err(OptimError::InvalidWeightDecay { weight_decay });
        }
        self.weight_decay = weight_decay;
        Ok(self)
    }

    /// Drops all accumulated per-parameter state.
    pub fn clear_state(&mut self) {
        self.state.clear();
    }

    pub fn learning_rate(&self) -> f32 {
        self.lr
    }

    pub fn set_learning_rate(&mut self, lr: f32) -> Result<(), OptimError> {
        validate_lr(lr)?;
        self.lr = lr;
        Ok(())
    }
    /// Applies one Adagrad update to `weights` in place using `grad`.
    ///
    /// The accumulator for `parameter_id` is created on first use and reset
    /// if the parameter's shape changes between calls.
    pub fn step(
        &mut self,
        parameter_id: u64,
        weights: &mut Tensor,
        grad: &Tensor,
    ) -> Result<(), OptimError> {
        if weights.shape() != grad.shape() {
            return Err(OptimError::ShapeMismatch {
                weights: weights.shape().to_vec(),
                grad: grad.shape().to_vec(),
            });
        }

        let state = match self.state.entry(parameter_id) {
            Entry::Occupied(entry) => entry.into_mut(),
            Entry::Vacant(entry) => entry.insert(AdagradState::new(weights.shape())?),
        };
        if state.sum_sq.shape() != weights.shape() {
            state.reset(weights.shape())?;
        }

        let sum_sq = state.sum_sq.data_mut();
        let grad_values = grad.data();
        let weights_data = weights.data_mut();
        for index in 0..weights_data.len() {
            // Fold L2 weight decay into the gradient, accumulate its square,
            // then step: w <- w - lr * g / (sqrt(sum_sq) + eps).
            let grad_value = grad_values[index] + self.weight_decay * weights_data[index];
            sum_sq[index] += grad_value * grad_value;
            weights_data[index] -= self.lr * grad_value / (sum_sq[index].sqrt() + self.epsilon);
        }
        Ok(())
    }
    /// Updates the value of an autograd node using its accumulated gradient.
    ///
    /// Nodes that do not require a gradient are skipped; a node that requires
    /// one but has none reports `MissingGradient`.
    pub fn step_graph_node(&mut self, graph: &mut Graph, node: NodeId) -> Result<(), OptimError> {
        if !graph.requires_grad(node)? {
            return Ok(());
        }
        let grad = match graph.grad(node)? {
            Some(grad) => grad.clone(),
            None => return Err(OptimError::MissingGradient { node: node.0 }),
        };
        let weights = graph.value_mut(node)?;
        self.step(node.0 as u64, weights, &grad)
    }
}
impl LearningRate for Adagrad {
    fn learning_rate(&self) -> f32 {
        Adagrad::learning_rate(self)
    }

    fn set_learning_rate(&mut self, lr: f32) -> Result<(), OptimError> {
        Adagrad::set_learning_rate(self, lr)
    }
}
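
// A minimal usage sketch, written as unit tests. It exercises only the API
// shown in this file (`Adagrad::new`, the builder methods, `step`, and the
// `Tensor::zeros`/`data`/`data_mut` calls used above) and assumes the error
// types implement `Debug` so `unwrap` is available in tests; the numeric
// expectation follows directly from the update rule in `step`.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn rejects_negative_weight_decay() {
        // Building with an invalid weight decay should fail validation.
        let result = Adagrad::new(0.1).unwrap().with_weight_decay(-0.5);
        assert!(result.is_err());
    }

    #[test]
    fn first_step_moves_weights_by_roughly_lr() {
        // With zero weights and a gradient of 1.0 everywhere, the first update
        // is lr * g / (sqrt(g^2) + eps) ~= lr, so each weight moves to about -0.1.
        let mut optim = Adagrad::new(0.1).unwrap();
        let mut weights = Tensor::zeros(vec![3]).unwrap();
        let mut grad = Tensor::zeros(vec![3]).unwrap();
        for g in grad.data_mut() {
            *g = 1.0;
        }
        optim.step(0, &mut weights, &grad).unwrap();
        for w in weights.data() {
            assert!((*w + 0.1).abs() < 1e-4);
        }
    }
}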