1use std::collections::HashMap;
2use std::collections::hash_map::Entry;
3
4use yscv_autograd::{Graph, NodeId};
5use yscv_tensor::Tensor;
6
7use super::validate::{validate_epsilon, validate_lr};
8use super::{LearningRate, OptimError};
9
/// Per-parameter optimizer state for Adagrad.
#[derive(Debug, Clone)]
struct AdagradState {
    // Running sum of squared gradients; allocated with the parameter's
    // shape and grown monotonically by `Adagrad::step`.
    sum_sq: Tensor,
}
14
15impl AdagradState {
16 fn new(shape: &[usize]) -> Result<Self, OptimError> {
17 Ok(Self {
18 sum_sq: Tensor::zeros(shape.to_vec())?,
19 })
20 }
21
22 fn reset(&mut self, shape: &[usize]) -> Result<(), OptimError> {
23 *self = Self::new(shape)?;
24 Ok(())
25 }
26}
27
/// Adagrad optimizer with optional L2 weight decay.
#[derive(Debug, Clone)]
pub struct Adagrad {
    // Step size; validated by `validate_lr` on construction and update.
    lr: f32,
    // Denominator stabilizer added to the accumulator's square root
    // (defaults to 1e-10 in `new`).
    epsilon: f32,
    // L2 penalty coefficient folded into the gradient (defaults to 0.0).
    weight_decay: f32,
    // Per-parameter accumulators, keyed by the caller-supplied parameter id.
    state: HashMap<u64, AdagradState>,
}
36
37impl Adagrad {
38 pub fn new(lr: f32) -> Result<Self, OptimError> {
40 validate_lr(lr)?;
41 Ok(Self {
42 lr,
43 epsilon: 1e-10,
44 weight_decay: 0.0,
45 state: HashMap::new(),
46 })
47 }
48
49 pub fn with_epsilon(mut self, epsilon: f32) -> Result<Self, OptimError> {
51 validate_epsilon(epsilon)?;
52 self.epsilon = epsilon;
53 Ok(self)
54 }
55
56 pub fn with_weight_decay(mut self, weight_decay: f32) -> Result<Self, OptimError> {
58 if !weight_decay.is_finite() || weight_decay < 0.0 {
59 return Err(OptimError::InvalidWeightDecay { weight_decay });
60 }
61 self.weight_decay = weight_decay;
62 Ok(self)
63 }
64
65 pub fn clear_state(&mut self) {
67 self.state.clear();
68 }
69
70 pub fn learning_rate(&self) -> f32 {
72 self.lr
73 }
74
75 pub fn set_learning_rate(&mut self, lr: f32) -> Result<(), OptimError> {
77 validate_lr(lr)?;
78 self.lr = lr;
79 Ok(())
80 }
81
82 pub fn step(
84 &mut self,
85 parameter_id: u64,
86 weights: &mut Tensor,
87 grad: &Tensor,
88 ) -> Result<(), OptimError> {
89 if weights.shape() != grad.shape() {
90 return Err(OptimError::ShapeMismatch {
91 weights: weights.shape().to_vec(),
92 grad: grad.shape().to_vec(),
93 });
94 }
95
96 let state = match self.state.entry(parameter_id) {
97 Entry::Occupied(entry) => entry.into_mut(),
98 Entry::Vacant(entry) => entry.insert(AdagradState::new(weights.shape())?),
99 };
100 if state.sum_sq.shape() != weights.shape() {
101 state.reset(weights.shape())?;
102 }
103
104 let sum_sq = state.sum_sq.data_mut();
105 let grad_values = grad.data();
106 let weights_data = weights.data_mut();
107
108 for index in 0..weights_data.len() {
109 let grad_value = grad_values[index] + self.weight_decay * weights_data[index];
110 sum_sq[index] += grad_value * grad_value;
111 weights_data[index] -= self.lr * grad_value / (sum_sq[index].sqrt() + self.epsilon);
112 }
113
114 Ok(())
115 }
116
117 pub fn step_graph_node(&mut self, graph: &mut Graph, node: NodeId) -> Result<(), OptimError> {
119 if !graph.requires_grad(node)? {
120 return Ok(());
121 }
122
123 let grad = match graph.grad(node)? {
124 Some(grad) => grad.clone(),
125 None => return Err(OptimError::MissingGradient { node: node.0 }),
126 };
127 let weights = graph.value_mut(node)?;
128 self.step(node.0 as u64, weights, &grad)
129 }
130}
131
132impl LearningRate for Adagrad {
133 fn learning_rate(&self) -> f32 {
134 Adagrad::learning_rate(self)
135 }
136
137 fn set_learning_rate(&mut self, lr: f32) -> Result<(), OptimError> {
138 Adagrad::set_learning_rate(self, lr)
139 }
140}