1use std::collections::HashMap;
2use std::collections::hash_map::Entry;
3
4use yscv_autograd::{Graph, NodeId};
5use yscv_tensor::Tensor;
6
7use super::validate::{validate_epsilon, validate_lr, validate_momentum, validate_rmsprop_alpha};
8use super::{LearningRate, OptimError};
9
/// Per-parameter accumulator buffers for [`RmsProp`], all shaped like the
/// parameter tensor they track.
#[derive(Debug, Clone)]
struct RmsPropState {
    // Exponential moving average of squared gradients (the denominator term).
    square_avg: Tensor,
    // Exponential moving average of raw gradients; only written when the
    // optimizer runs in centered mode.
    grad_avg: Tensor,
    // Momentum accumulator; only written when `momentum != 0.0`.
    momentum_buffer: Tensor,
}
16
17impl RmsPropState {
18 fn new(shape: &[usize]) -> Result<Self, OptimError> {
19 Ok(Self {
20 square_avg: Tensor::zeros(shape.to_vec())?,
21 grad_avg: Tensor::zeros(shape.to_vec())?,
22 momentum_buffer: Tensor::zeros(shape.to_vec())?,
23 })
24 }
25
26 fn reset(&mut self, shape: &[usize]) -> Result<(), OptimError> {
27 *self = Self::new(shape)?;
28 Ok(())
29 }
30}
31
/// RMSProp optimizer.
///
/// Keeps per-parameter running averages keyed by a caller-supplied id and
/// supports optional weight decay, momentum, and a centered variant.
#[derive(Debug, Clone)]
pub struct RmsProp {
    // Step size applied to each update.
    lr: f32,
    // Smoothing constant for the running averages (defaults to 0.99).
    alpha: f32,
    // Small constant added to the denominator for numerical stability.
    epsilon: f32,
    // L2 penalty folded into the gradient before the update; 0.0 disables it.
    weight_decay: f32,
    // Momentum factor; 0.0 disables the momentum buffer entirely.
    momentum: f32,
    // When true, the second-moment estimate is centered by subtracting the
    // squared running mean of the gradients.
    centered: bool,
    // Per-parameter accumulators, keyed by the id passed to `step`.
    state: HashMap<u64, RmsPropState>,
}
43
44impl RmsProp {
45 pub fn new(lr: f32) -> Result<Self, OptimError> {
47 validate_lr(lr)?;
48 Ok(Self {
49 lr,
50 alpha: 0.99,
51 epsilon: 1e-8,
52 weight_decay: 0.0,
53 momentum: 0.0,
54 centered: false,
55 state: HashMap::new(),
56 })
57 }
58
59 pub fn with_alpha(mut self, alpha: f32) -> Result<Self, OptimError> {
61 validate_rmsprop_alpha(alpha)?;
62 self.alpha = alpha;
63 Ok(self)
64 }
65
66 pub fn with_epsilon(mut self, epsilon: f32) -> Result<Self, OptimError> {
68 validate_epsilon(epsilon)?;
69 self.epsilon = epsilon;
70 Ok(self)
71 }
72
73 pub fn with_weight_decay(mut self, weight_decay: f32) -> Result<Self, OptimError> {
75 if !weight_decay.is_finite() || weight_decay < 0.0 {
76 return Err(OptimError::InvalidWeightDecay { weight_decay });
77 }
78 self.weight_decay = weight_decay;
79 Ok(self)
80 }
81
82 pub fn with_momentum(mut self, momentum: f32) -> Result<Self, OptimError> {
84 validate_momentum(momentum)?;
85 self.momentum = momentum;
86 Ok(self)
87 }
88
89 pub fn with_centered(mut self, centered: bool) -> Self {
91 self.centered = centered;
92 self
93 }
94
95 pub fn clear_state(&mut self) {
97 self.state.clear();
98 }
99
100 pub fn learning_rate(&self) -> f32 {
102 self.lr
103 }
104
105 pub fn set_learning_rate(&mut self, lr: f32) -> Result<(), OptimError> {
107 validate_lr(lr)?;
108 self.lr = lr;
109 Ok(())
110 }
111
112 pub fn step(
114 &mut self,
115 parameter_id: u64,
116 weights: &mut Tensor,
117 grad: &Tensor,
118 ) -> Result<(), OptimError> {
119 if weights.shape() != grad.shape() {
120 return Err(OptimError::ShapeMismatch {
121 weights: weights.shape().to_vec(),
122 grad: grad.shape().to_vec(),
123 });
124 }
125
126 let state = match self.state.entry(parameter_id) {
127 Entry::Occupied(entry) => entry.into_mut(),
128 Entry::Vacant(entry) => entry.insert(RmsPropState::new(weights.shape())?),
129 };
130 if state.square_avg.shape() != weights.shape() {
131 state.reset(weights.shape())?;
132 }
133
134 let grad_values = grad.data();
135 let weights_data = weights.data_mut();
136 let square_avg = state.square_avg.data_mut();
137 let grad_avg = state.grad_avg.data_mut();
138 let momentum_buffer = state.momentum_buffer.data_mut();
139
140 let alpha = self.alpha;
141 let one_minus_alpha = 1.0 - self.alpha;
142
143 for index in 0..weights_data.len() {
144 let grad_value = grad_values[index] + self.weight_decay * weights_data[index];
145 square_avg[index] =
146 alpha * square_avg[index] + one_minus_alpha * grad_value * grad_value;
147
148 let avg = if self.centered {
149 grad_avg[index] = alpha * grad_avg[index] + one_minus_alpha * grad_value;
150 (square_avg[index] - grad_avg[index] * grad_avg[index]).max(0.0)
151 } else {
152 square_avg[index]
153 };
154
155 let denom = avg.sqrt() + self.epsilon;
156 let normalized = grad_value / denom;
157 let update = if self.momentum != 0.0 {
158 let next = self.momentum * momentum_buffer[index] + normalized;
159 momentum_buffer[index] = next;
160 next
161 } else {
162 normalized
163 };
164 weights_data[index] -= self.lr * update;
165 }
166
167 Ok(())
168 }
169
170 pub fn step_graph_node(&mut self, graph: &mut Graph, node: NodeId) -> Result<(), OptimError> {
172 if !graph.requires_grad(node)? {
173 return Ok(());
174 }
175
176 let grad = match graph.grad(node)? {
177 Some(grad) => grad.clone(),
178 None => return Err(OptimError::MissingGradient { node: node.0 }),
179 };
180 let weights = graph.value_mut(node)?;
181 self.step(node.0 as u64, weights, &grad)
182 }
183}
184
185impl LearningRate for RmsProp {
186 fn learning_rate(&self) -> f32 {
187 RmsProp::learning_rate(self)
188 }
189
190 fn set_learning_rate(&mut self, lr: f32) -> Result<(), OptimError> {
191 RmsProp::set_learning_rate(self, lr)
192 }
193}