nabla_ml/nab_optimizers.rs
use crate::nab_array::NDArray;

pub struct NablaOptimizer;

impl NablaOptimizer {
    /// Performs Stochastic Gradient Descent (SGD) update
    ///
    /// w = w - learning_rate * gradient
    ///
    /// # Arguments
    ///
    /// * `weights` - NDArray of current weights to update
    /// * `gradient` - NDArray of gradients for the weights
    /// * `learning_rate` - Learning rate for the update
    ///
    /// # Example
    ///
    /// ```
    /// use nabla_ml::nab_array::NDArray;
    /// use nabla_ml::nab_optimizers::NablaOptimizer;
    ///
    /// let mut weights = NDArray::from_vec(vec![1.0, 2.0, 3.0]);
    /// let gradients = NDArray::from_vec(vec![0.1, 0.2, 0.3]);
    /// let learning_rate = 0.1;
    ///
    /// NablaOptimizer::sgd_update(&mut weights, &gradients, learning_rate);
    /// ```
    pub fn sgd_update(weights: &mut NDArray, gradient: &NDArray, learning_rate: f64) {
        let update = gradient.multiply_scalar(learning_rate);
        *weights = weights.subtract(&update);
    }

    /// Performs SGD update with momentum
    ///
    /// v = momentum * v - learning_rate * gradient
    /// w = w + v
    ///
    /// # Arguments
    ///
    /// * `weights` - NDArray of current weights to update
    /// * `gradient` - NDArray of gradients for the weights
    /// * `velocity` - Mutable reference to momentum velocity
    /// * `learning_rate` - Learning rate for the update
    /// * `momentum` - Momentum coefficient (default: 0.9)
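    ///
    /// # Example
    ///
    /// A minimal sketch mirroring the other optimizer examples in this module;
    /// it assumes the same `NDArray` constructors (`from_vec`, `zeros`) used above.
    ///
    /// ```
    /// use nabla_ml::nab_array::NDArray;
    /// use nabla_ml::nab_optimizers::NablaOptimizer;
    ///
    /// let mut weights = NDArray::from_vec(vec![1.0, 2.0, 3.0]);
    /// let gradients = NDArray::from_vec(vec![0.1, 0.2, 0.3]);
    /// let mut velocity = NDArray::zeros(vec![3]);
    /// let learning_rate = 0.1;
    /// let momentum = 0.9;
    ///
    /// NablaOptimizer::sgd_momentum_update(
    ///     &mut weights,
    ///     &gradients,
    ///     &mut velocity,
    ///     learning_rate,
    ///     momentum
    /// );
    /// ```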
    pub fn sgd_momentum_update(
        weights: &mut NDArray,
        gradient: &NDArray,
        velocity: &mut NDArray,
        learning_rate: f64,
        momentum: f64,
    ) {
        // Update velocity
        *velocity = velocity.multiply_scalar(momentum)
            .subtract(&gradient.multiply_scalar(learning_rate));

        // Update weights using velocity
        *weights = weights.clone().add(velocity);
    }

    /// Performs RMSprop update
    ///
    /// cache = decay_rate * cache + (1 - decay_rate) * gradient^2
    /// w = w - learning_rate * gradient / (sqrt(cache) + epsilon)
    ///
    /// # Arguments
    ///
    /// * `weights` - NDArray of current weights to update
    /// * `gradient` - NDArray of gradients for the weights
    /// * `cache` - Running average of squared gradients
    /// * `learning_rate` - Learning rate for the update
    /// * `decay_rate` - Decay rate for running average (default: 0.9)
    /// * `epsilon` - Small value for numerical stability (default: 1e-8)
    ///
    /// # Example
    ///
    /// ```
    /// use nabla_ml::nab_array::NDArray;
    /// use nabla_ml::nab_optimizers::NablaOptimizer;
    ///
    /// let mut weights = NDArray::from_vec(vec![1.0, 2.0, 3.0]);
    /// let gradients = NDArray::from_vec(vec![0.1, 0.2, 0.3]);
    /// let mut cache = NDArray::zeros(vec![3]);
    /// let learning_rate = 0.01;
    /// let decay_rate = 0.9;
    /// let epsilon = 1e-8;
    ///
    /// NablaOptimizer::rmsprop_update(
    ///     &mut weights,
    ///     &gradients,
    ///     &mut cache,
    ///     learning_rate,
    ///     decay_rate,
    ///     epsilon
    /// );
    /// ```
    pub fn rmsprop_update(
        weights: &mut NDArray,
        gradient: &NDArray,
        cache: &mut NDArray,
        learning_rate: f64,
        decay_rate: f64,
        epsilon: f64,
    ) {
        // Update cache
        *cache = cache.multiply_scalar(decay_rate)
            .add(&gradient.multiply(gradient).multiply_scalar(1.0 - decay_rate));

        // Compute update
        let update = gradient.divide(
            &cache.sqrt().add_scalar(epsilon)
        ).multiply_scalar(learning_rate);

        // Update weights
        *weights = weights.subtract(&update);
    }

    /// Performs Adam (Adaptive Moment Estimation) update
    ///
    /// m = beta1 * m + (1 - beta1) * gradient     // Update first moment
    /// v = beta2 * v + (1 - beta2) * gradient^2   // Update second moment
    /// m_hat = m / (1 - beta1^t)                  // Bias correction
    /// v_hat = v / (1 - beta2^t)                  // Bias correction
    /// w = w - learning_rate * m_hat / (sqrt(v_hat) + epsilon)
    ///
    /// # Arguments
    ///
    /// * `weights` - NDArray of current weights to update
    /// * `gradient` - NDArray of gradients for the weights
    /// * `m` - First moment vector (momentum)
    /// * `v` - Second moment vector (uncentered variance)
    /// * `t` - Current timestep (starting from 1)
    /// * `learning_rate` - Learning rate for the update
    /// * `beta1` - Exponential decay rate for first moment (default: 0.9)
    /// * `beta2` - Exponential decay rate for second moment (default: 0.999)
    /// * `epsilon` - Small value for numerical stability (default: 1e-8)
    ///
    /// # Example
    ///
    /// ```
    /// use nabla_ml::nab_array::NDArray;
    /// use nabla_ml::nab_optimizers::NablaOptimizer;
    ///
    /// let mut weights = NDArray::from_vec(vec![1.0, 2.0, 3.0]);
    /// let gradients = NDArray::from_vec(vec![0.1, 0.2, 0.3]);
    /// let mut m = NDArray::zeros(vec![3]);
    /// let mut v = NDArray::zeros(vec![3]);
    /// let t = 1;
    /// let learning_rate = 0.001;
    /// let beta1 = 0.9;
    /// let beta2 = 0.999;
    /// let epsilon = 1e-8;
    ///
    /// NablaOptimizer::adam_update(
    ///     &mut weights,
    ///     &gradients,
    ///     &mut m,
    ///     &mut v,
    ///     t,
    ///     learning_rate,
    ///     beta1,
    ///     beta2,
    ///     epsilon
    /// );
    /// ```
    pub fn adam_update(
        weights: &mut NDArray,
        gradient: &NDArray,
        m: &mut NDArray,
        v: &mut NDArray,
        t: usize,
        learning_rate: f64,
        beta1: f64,
        beta2: f64,
        epsilon: f64,
    ) {
        // Update biased first moment estimate
        *m = m.multiply_scalar(beta1)
            .add(&gradient.multiply_scalar(1.0 - beta1));

        // Update biased second raw moment estimate
        *v = v.multiply_scalar(beta2)
            .add(&gradient.multiply(gradient).multiply_scalar(1.0 - beta2));

        // Compute bias-corrected first moment estimate
        let m_hat = m.multiply_scalar(1.0 / (1.0 - beta1.powi(t as i32)));

        // Compute bias-corrected second raw moment estimate
        let v_hat = v.multiply_scalar(1.0 / (1.0 - beta2.powi(t as i32)));

        // Compute the update
        let update = m_hat.divide(&v_hat.sqrt().add_scalar(epsilon))
            .multiply_scalar(learning_rate);

        // Apply update to weights
        *weights = weights.subtract(&update);
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_sgd_update() {
        // Initialize test data
        let mut weights = NDArray::from_vec(vec![1.0, 2.0, 3.0]);
        let gradients = NDArray::from_vec(vec![0.1, 0.2, 0.3]);
        let learning_rate = 0.1;

        // Store initial weights
        let initial_weights = weights.clone();

        // Perform update
        NablaOptimizer::sgd_update(&mut weights, &gradients, learning_rate);

        // Verify weights were updated correctly
        for i in 0..weights.data().len() {
            let expected = initial_weights.data()[i] - learning_rate * gradients.data()[i];
            assert!((weights.data()[i] - expected).abs() < 1e-6);
        }
    }

    #[test]
    fn test_sgd_momentum() {
        // Initialize test data
        let mut weights = NDArray::from_vec(vec![1.0, 2.0, 3.0]);
        let gradients = NDArray::from_vec(vec![0.1, 0.2, 0.3]);
        let mut velocity = NDArray::zeros(vec![3]);
        let learning_rate = 0.1;
        let momentum = 0.9;

        // Store initial weights
        let initial_weights = weights.clone();

        // Perform update
        NablaOptimizer::sgd_momentum_update(
            &mut weights,
            &gradients,
            &mut velocity,
            learning_rate,
            momentum
        );

        // Verify weights changed
        assert!(weights.data() != initial_weights.data());

        // Verify velocity is non-zero
        assert!(velocity.data().iter().any(|&x| x != 0.0));

        // Verify momentum effect (velocity should be -learning_rate * gradients)
        for i in 0..velocity.data().len() {
            assert!((velocity.data()[i] + learning_rate * gradients.data()[i]).abs() < 1e-6);
        }
    }

    #[test]
    fn test_rmsprop_update() {
        // Initialize test data
        let mut weights = NDArray::from_vec(vec![1.0, 2.0, 3.0]);
        let gradients = NDArray::from_vec(vec![0.1, 0.2, 0.3]);
        let mut cache = NDArray::zeros(vec![3]);
        let learning_rate = 0.01;
        let decay_rate = 0.9;
        let epsilon = 1e-8;

        // Store initial values
        let initial_weights = weights.clone();
        let initial_cache = cache.clone();

        // Perform update
        NablaOptimizer::rmsprop_update(
            &mut weights,
            &gradients,
            &mut cache,
            learning_rate,
            decay_rate,
            epsilon
        );

        // Verify weights changed
        assert!(weights.data() != initial_weights.data(),
            "Weights should be updated");

        // Verify cache was updated
        assert!(cache.data() != initial_cache.data(),
            "Cache should be updated");

        // Verify cache contains squared gradient information
        for i in 0..cache.data().len() {
            let expected_cache = (1.0 - decay_rate) * gradients.data()[i].powi(2);
            assert!((cache.data()[i] - expected_cache).abs() < 1e-6,
                "Cache should contain squared gradient information");
        }

        // Test multiple updates to verify cache accumulation
        let prev_cache = cache.clone();
        NablaOptimizer::rmsprop_update(
            &mut weights,
            &gradients,
            &mut cache,
            learning_rate,
            decay_rate,
            epsilon
        );

        // Verify cache accumulates gradient information across updates
        for i in 0..cache.data().len() {
            assert!(cache.data()[i] > prev_cache.data()[i],
                "Cache should accumulate gradient information");
        }
    }

    #[test]
    fn test_adam_update() {
        // Initialize test data
        let mut weights = NDArray::from_vec(vec![1.0, 2.0, 3.0]);
        let gradients = NDArray::from_vec(vec![0.1, 0.2, 0.3]);
        let mut m = NDArray::zeros(vec![3]);
        let mut v = NDArray::zeros(vec![3]);
        let t = 1;
        let learning_rate = 0.001;
        let beta1 = 0.9;
        let beta2 = 0.999;
        let epsilon = 1e-8;

        // Store initial values
        let initial_weights = weights.clone();
        let initial_m = m.clone();
        let initial_v = v.clone();

        // Perform update
        NablaOptimizer::adam_update(
            &mut weights,
            &gradients,
            &mut m,
            &mut v,
            t,
            learning_rate,
            beta1,
            beta2,
            epsilon
        );

        // Verify weights changed
        assert!(weights.data() != initial_weights.data(),
            "Weights should be updated");

        // Verify moment estimates changed
        assert!(m.data() != initial_m.data(),
            "First moment should be updated");
        assert!(v.data() != initial_v.data(),
            "Second moment should be updated");

        // Verify first moment update
        for i in 0..m.data().len() {
            let expected_m = (1.0 - beta1) * gradients.data()[i];
            assert!((m.data()[i] - expected_m).abs() < 1e-6,
                "First moment should be correctly updated");
        }

        // Verify second moment update
        for i in 0..v.data().len() {
            let expected_v = (1.0 - beta2) * gradients.data()[i].powi(2);
            assert!((v.data()[i] - expected_v).abs() < 1e-6,
                "Second moment should be correctly updated");
        }

        // Test multiple updates
        let prev_m = m.clone();
        let prev_v = v.clone();

        NablaOptimizer::adam_update(
            &mut weights,
            &gradients,
            &mut m,
            &mut v,
            t + 1,
            learning_rate,
            beta1,
            beta2,
            epsilon
        );

        // Verify moment accumulation
        assert!(m.data().iter().zip(prev_m.data().iter())
            .all(|(&new, &old)| new != old),
            "First moment should accumulate");
        assert!(v.data().iter().zip(prev_v.data().iter())
            .all(|(&new, &old)| new != old),
            "Second moment should accumulate");
    }
}