// nabla_ml/nab_optimizers.rs

use crate::nab_array::NDArray;

pub struct NablaOptimizer;

impl NablaOptimizer {

    /// Performs Stochastic Gradient Descent (SGD) update
    ///
    /// w = w - learning_rate * gradient
    ///
    /// # Arguments
    ///
    /// * `weights` - NDArray of current weights to update
    /// * `gradient` - NDArray of gradients for the weights
    /// * `learning_rate` - Learning rate for the update
    ///
    /// # Example
    ///
    /// ```
    /// use nabla_ml::nab_array::NDArray;
    /// use nabla_ml::nab_optimizers::NablaOptimizer;
    ///
    /// let mut weights = NDArray::from_vec(vec![1.0, 2.0, 3.0]);
    /// let gradients = NDArray::from_vec(vec![0.1, 0.2, 0.3]);
    /// let learning_rate = 0.1;
    ///
    /// NablaOptimizer::sgd_update(&mut weights, &gradients, learning_rate);
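    /// // With these values the updated weights are approximately [0.99, 1.98, 2.97].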
    /// ```
    pub fn sgd_update(weights: &mut NDArray, gradient: &NDArray, learning_rate: f64) {
        let update = gradient.multiply_scalar(learning_rate);
        *weights = weights.subtract(&update);
    }

    /// Performs SGD update with momentum
    ///
    /// v = momentum * v - learning_rate * gradient
    /// w = w + v
    ///
    /// # Arguments
    ///
    /// * `weights` - NDArray of current weights to update
    /// * `gradient` - NDArray of gradients for the weights
    /// * `velocity` - Mutable reference to momentum velocity
    /// * `learning_rate` - Learning rate for the update
    /// * `momentum` - Momentum coefficient (default: 0.9)
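    ///
    /// # Example
    ///
    /// A minimal sketch in the style of the other doctests in this module,
    /// reusing the `from_vec` and `zeros` constructors shown above:
    ///
    /// ```
    /// use nabla_ml::nab_array::NDArray;
    /// use nabla_ml::nab_optimizers::NablaOptimizer;
    ///
    /// let mut weights = NDArray::from_vec(vec![1.0, 2.0, 3.0]);
    /// let gradients = NDArray::from_vec(vec![0.1, 0.2, 0.3]);
    /// let mut velocity = NDArray::zeros(vec![3]);
    ///
    /// NablaOptimizer::sgd_momentum_update(&mut weights, &gradients, &mut velocity, 0.1, 0.9);
    /// // With zero initial velocity, the first step is simply -0.1 * gradient added to the weights.
    /// ```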
    pub fn sgd_momentum_update(
        weights: &mut NDArray,
        gradient: &NDArray,
        velocity: &mut NDArray,
        learning_rate: f64,
        momentum: f64,
    ) {
        // Update velocity
        *velocity = velocity.multiply_scalar(momentum)
            .subtract(&gradient.multiply_scalar(learning_rate));

        // Update weights using velocity
        *weights = weights.clone().add(velocity);
    }

    /// Performs RMSprop update
    ///
    /// cache = decay_rate * cache + (1 - decay_rate) * gradient^2
    /// w = w - learning_rate * gradient / (sqrt(cache) + epsilon)
    ///
    /// # Arguments
    ///
    /// * `weights` - NDArray of current weights to update
    /// * `gradient` - NDArray of gradients for the weights
    /// * `cache` - Running average of squared gradients
    /// * `learning_rate` - Learning rate for the update
    /// * `decay_rate` - Decay rate for running average (default: 0.9)
    /// * `epsilon` - Small value for numerical stability (default: 1e-8)
    ///
    /// # Example
    ///
    /// ```
    /// use nabla_ml::nab_array::NDArray;
    /// use nabla_ml::nab_optimizers::NablaOptimizer;
    ///
    /// let mut weights = NDArray::from_vec(vec![1.0, 2.0, 3.0]);
    /// let gradients = NDArray::from_vec(vec![0.1, 0.2, 0.3]);
    /// let mut cache = NDArray::zeros(vec![3]);
    /// let learning_rate = 0.01;
    /// let decay_rate = 0.9;
    /// let epsilon = 1e-8;
    ///
    /// NablaOptimizer::rmsprop_update(
    ///     &mut weights,
    ///     &gradients,
    ///     &mut cache,
    ///     learning_rate,
    ///     decay_rate,
    ///     epsilon
    /// );
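    /// // Starting from a zero cache, this step leaves cache = (1 - decay_rate) * gradient^2 elementwise.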
    /// ```
    pub fn rmsprop_update(
        weights: &mut NDArray,
        gradient: &NDArray,
        cache: &mut NDArray,
        learning_rate: f64,
        decay_rate: f64,
        epsilon: f64,
    ) {
        // Update cache
        *cache = cache.multiply_scalar(decay_rate)
            .add(&gradient.multiply(gradient).multiply_scalar(1.0 - decay_rate));

        // Compute update
        let update = gradient.divide(
            &cache.sqrt().add_scalar(epsilon)
        ).multiply_scalar(learning_rate);

        // Update weights
        *weights = weights.subtract(&update);
    }

    /// Performs Adam (Adaptive Moment Estimation) update
    ///
    /// m = beta1 * m + (1 - beta1) * gradient           // Update first moment
    /// v = beta2 * v + (1 - beta2) * gradient^2         // Update second moment
    /// m_hat = m / (1 - beta1^t)                        // Bias correction
    /// v_hat = v / (1 - beta2^t)                        // Bias correction
    /// w = w - learning_rate * m_hat / (sqrt(v_hat) + epsilon)
    ///
    /// # Arguments
    ///
    /// * `weights` - NDArray of current weights to update
    /// * `gradient` - NDArray of gradients for the weights
    /// * `m` - First moment vector (momentum)
    /// * `v` - Second moment vector (uncentered variance)
    /// * `t` - Current timestep (starting from 1)
    /// * `learning_rate` - Learning rate for the update
    /// * `beta1` - Exponential decay rate for first moment (default: 0.9)
    /// * `beta2` - Exponential decay rate for second moment (default: 0.999)
    /// * `epsilon` - Small value for numerical stability (default: 1e-8)
    ///
    /// # Example
    ///
    /// ```
    /// use nabla_ml::nab_array::NDArray;
    /// use nabla_ml::nab_optimizers::NablaOptimizer;
    ///
    /// let mut weights = NDArray::from_vec(vec![1.0, 2.0, 3.0]);
    /// let gradients = NDArray::from_vec(vec![0.1, 0.2, 0.3]);
    /// let mut m = NDArray::zeros(vec![3]);
    /// let mut v = NDArray::zeros(vec![3]);
    /// let t = 1;
    /// let learning_rate = 0.001;
    /// let beta1 = 0.9;
    /// let beta2 = 0.999;
    /// let epsilon = 1e-8;
    ///
    /// NablaOptimizer::adam_update(
    ///     &mut weights,
    ///     &gradients,
    ///     &mut m,
    ///     &mut v,
    ///     t,
    ///     learning_rate,
    ///     beta1,
    ///     beta2,
    ///     epsilon
    /// );
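    /// // With zero-initialized moments and t = 1, the bias-corrected step has magnitude close to learning_rate for each element.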
    /// ```
    pub fn adam_update(
        weights: &mut NDArray,
        gradient: &NDArray,
        m: &mut NDArray,
        v: &mut NDArray,
        t: usize,
        learning_rate: f64,
        beta1: f64,
        beta2: f64,
        epsilon: f64,
    ) {
        // Update biased first moment estimate
        *m = m.multiply_scalar(beta1)
            .add(&gradient.multiply_scalar(1.0 - beta1));

        // Update biased second raw moment estimate
        *v = v.multiply_scalar(beta2)
            .add(&gradient.multiply(gradient).multiply_scalar(1.0 - beta2));

        // Compute bias-corrected first moment estimate
        let m_hat = m.multiply_scalar(1.0 / (1.0 - beta1.powi(t as i32)));

        // Compute bias-corrected second raw moment estimate
        let v_hat = v.multiply_scalar(1.0 / (1.0 - beta2.powi(t as i32)));

        // Compute the update
        let update = m_hat.divide(&v_hat.sqrt().add_scalar(epsilon))
            .multiply_scalar(learning_rate);

        // Apply update to weights
        *weights = weights.subtract(&update);
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_sgd_update() {
        // Initialize test data
        let mut weights = NDArray::from_vec(vec![1.0, 2.0, 3.0]);
        let gradients = NDArray::from_vec(vec![0.1, 0.2, 0.3]);
        let learning_rate = 0.1;

        // Store initial weights
        let initial_weights = weights.clone();

        // Perform update
        NablaOptimizer::sgd_update(&mut weights, &gradients, learning_rate);

        // Verify weights were updated correctly
        for i in 0..weights.data().len() {
            let expected = initial_weights.data()[i] - learning_rate * gradients.data()[i];
            assert!((weights.data()[i] - expected).abs() < 1e-6);
        }
    }

    #[test]
    fn test_sgd_momentum() {
        // Initialize test data
        let mut weights = NDArray::from_vec(vec![1.0, 2.0, 3.0]);
        let gradients = NDArray::from_vec(vec![0.1, 0.2, 0.3]);
        let mut velocity = NDArray::zeros(vec![3]);
        let learning_rate = 0.1;
        let momentum = 0.9;

        // Store initial weights
        let initial_weights = weights.clone();

        // Perform update
        NablaOptimizer::sgd_momentum_update(
            &mut weights,
            &gradients,
            &mut velocity,
            learning_rate,
            momentum
        );

        // Verify weights changed
        assert!(weights.data() != initial_weights.data());

        // Verify velocity is non-zero
        assert!(velocity.data().iter().any(|&x| x != 0.0));

        // Verify momentum effect (velocity should be -learning_rate * gradients)
        for i in 0..velocity.data().len() {
            assert!((velocity.data()[i] + learning_rate * gradients.data()[i]).abs() < 1e-6);
        }
    }

    #[test]
    fn test_rmsprop_update() {
        // Initialize test data
        let mut weights = NDArray::from_vec(vec![1.0, 2.0, 3.0]);
        let gradients = NDArray::from_vec(vec![0.1, 0.2, 0.3]);
        let mut cache = NDArray::zeros(vec![3]);
        let learning_rate = 0.01;
        let decay_rate = 0.9;
        let epsilon = 1e-8;

        // Store initial values
        let initial_weights = weights.clone();
        let initial_cache = cache.clone();

        // Perform update
        NablaOptimizer::rmsprop_update(
            &mut weights,
            &gradients,
            &mut cache,
            learning_rate,
            decay_rate,
            epsilon
        );

        // Verify weights changed
        assert!(weights.data() != initial_weights.data(),
            "Weights should be updated");

        // Verify cache was updated
        assert!(cache.data() != initial_cache.data(),
            "Cache should be updated");

        // Verify cache contains squared gradient information
        for i in 0..cache.data().len() {
            let expected_cache = (1.0 - decay_rate) * gradients.data()[i].powi(2);
            assert!((cache.data()[i] - expected_cache).abs() < 1e-6,
                "Cache should contain squared gradient information");
        }

        // Test multiple updates to verify cache accumulation
        let prev_cache = cache.clone();
        NablaOptimizer::rmsprop_update(
            &mut weights,
            &gradients,
            &mut cache,
            learning_rate,
            decay_rate,
            epsilon
        );

        // Verify cache accumulation: each entry should grow after the second update
        for i in 0..cache.data().len() {
            assert!(cache.data()[i] > prev_cache.data()[i],
                "Cache should accumulate gradient information");
        }
    }

    #[test]
    fn test_adam_update() {
        // Initialize test data
        let mut weights = NDArray::from_vec(vec![1.0, 2.0, 3.0]);
        let gradients = NDArray::from_vec(vec![0.1, 0.2, 0.3]);
        let mut m = NDArray::zeros(vec![3]);
        let mut v = NDArray::zeros(vec![3]);
        let t = 1;
        let learning_rate = 0.001;
        let beta1 = 0.9;
        let beta2 = 0.999;
        let epsilon = 1e-8;

        // Store initial values
        let initial_weights = weights.clone();
        let initial_m = m.clone();
        let initial_v = v.clone();

        // Perform update
        NablaOptimizer::adam_update(
            &mut weights,
            &gradients,
            &mut m,
            &mut v,
            t,
            learning_rate,
            beta1,
            beta2,
            epsilon
        );

        // Verify weights changed
        assert!(weights.data() != initial_weights.data(),
            "Weights should be updated");

        // Verify moment estimates changed
        assert!(m.data() != initial_m.data(),
            "First moment should be updated");
        assert!(v.data() != initial_v.data(),
            "Second moment should be updated");

        // Verify first moment update
        for i in 0..m.data().len() {
            let expected_m = (1.0 - beta1) * gradients.data()[i];
            assert!((m.data()[i] - expected_m).abs() < 1e-6,
                "First moment should be correctly updated");
        }

        // Verify second moment update
        for i in 0..v.data().len() {
            let expected_v = (1.0 - beta2) * gradients.data()[i].powi(2);
            assert!((v.data()[i] - expected_v).abs() < 1e-6,
                "Second moment should be correctly updated");
        }

        // Test multiple updates
        let prev_m = m.clone();
        let prev_v = v.clone();

        NablaOptimizer::adam_update(
            &mut weights,
            &gradients,
            &mut m,
            &mut v,
            t + 1,
            learning_rate,
            beta1,
            beta2,
            epsilon
        );

        // Verify moment accumulation
        assert!(m.data().iter().zip(prev_m.data().iter())
            .all(|(&new, &old)| new != old),
            "First moment should accumulate");
        assert!(v.data().iter().zip(prev_v.data().iter())
            .all(|(&new, &old)| new != old),
            "Second moment should accumulate");
    }
}