// opensrdk_optimization/sgd_adam.rs

1use crate::{vec::Vector as _, Status};
2use opensrdk_linear_algebra::*;
3use rand::{rngs::StdRng, seq::SliceRandom, SeedableRng};
4
/// # Stochastic Gradient Descent Adam
///
/// Mini-batch stochastic gradient descent using the Adam update rule.
///
/// If you are indecisive for values, it is sufficient to use `default()`.
///
/// - `epsilon`: Threshold of the norm of gradients for finishing
/// - `max_iter`: Max limitation of iteration count (`0` disables the cap; see `minimize`)
/// - `alpha`: Learning rate
/// - `beta1`: Weight for moving average of Momentum (first moment)
/// - `beta2`: Weight for moving average of RMSProp (second moment)
/// - `e`: A small value for preventing zero divisions
pub struct SgdAdam {
    epsilon: f64,
    max_iter: usize,
    alpha: f64,
    beta1: f64,
    beta2: f64,
    e: f64,
}
23
24impl Default for SgdAdam {
25    fn default() -> Self {
26        Self {
27            epsilon: 1e-6,
28            max_iter: 0,
29            alpha: 0.001,
30            beta1: 0.9,
31            beta2: 0.999,
32            e: 0.00000001,
33        }
34    }
35}
36
37impl SgdAdam {
38    pub fn new(epsilon: f64, max_iter: usize, alpha: f64, beta1: f64, beta2: f64, e: f64) -> Self {
39        Self {
40            epsilon,
41            max_iter,
42            alpha,
43            beta1,
44            beta2,
45            e,
46        }
47    }
48
49    pub fn with_epsilon(mut self, epsilon: f64) -> Self {
50        assert!(epsilon.is_sign_positive(), "epsilon must be positive");
51
52        self.epsilon = epsilon;
53
54        self
55    }
56
57    pub fn with_max_iter(mut self, max_iter: usize) -> Self {
58        self.max_iter = max_iter;
59
60        self
61    }
62
63    pub fn with_alpha(mut self, alpha: f64) -> Self {
64        assert!(alpha.is_sign_positive(), "delta must be positive");
65
66        self.alpha = alpha;
67
68        self
69    }
70
71    pub fn with_beta1(mut self, beta1: f64) -> Self {
72        assert!(beta1.is_sign_positive(), "beta1 must be positive");
73
74        self.beta2 = beta1;
75
76        self
77    }
78
79    pub fn with_beta2(mut self, beta2: f64) -> Self {
80        assert!(beta2.is_sign_positive(), "beta2 must be positive");
81
82        self.beta2 = beta2;
83
84        self
85    }
86
87    pub fn with_e(mut self, e: f64) -> Self {
88        assert!(e.is_sign_positive(), "e must be positive");
89
90        self.e = e;
91
92        self
93    }
94
95    /// - `x`: Objective variables
96    /// - `grad` Returns the gradients of each inputs
97    ///   - `&[usize]`: Set of indices included in the batch
98    ///   - `&[f64]`: Objective variables as an input
99    /// - `batch`: Batch size
100    /// - `total`: Total data size
101    pub fn minimize(
102        &self,
103        x: &mut [f64],
104        grad: &dyn Fn(&[usize], &[f64]) -> Vec<f64>,
105        batch: usize,
106        total: usize,
107    ) -> Status {
108        let mut batch_index = (0..total).into_iter().collect::<Vec<_>>();
109        let mut w = x.to_vec().col_mat();
110        let mut m = Matrix::new(w.rows(), 1);
111        let mut v = Matrix::new(w.rows(), 1);
112        let mut rng: StdRng = SeedableRng::seed_from_u64(1);
113
114        for k in 0.. {
115            if self.max_iter != 0 && self.max_iter <= k {
116                return Status::MaxIter;
117            }
118
119            let gfx = grad(&(0..total).into_iter().collect::<Vec<_>>(), w.slice());
120            if gfx.l2_norm() < self.epsilon + x.l2_norm() {
121                return Status::Epsilon;
122            }
123
124            batch_index.shuffle(&mut rng);
125
126            for minibatch in batch_index.chunks(batch) {
127                let minibatch_grad = grad(&minibatch, w.slice()).col_mat();
128
129                m = self.beta1 * m + (1.0 - self.beta1) * minibatch_grad.clone();
130                v = self.beta2 * v
131                    + (1.0 - self.beta2) * minibatch_grad.clone().hadamard_prod(&minibatch_grad);
132
133                let m_hat = m.clone() * (1.0 / (1.0 - self.beta1.powi(k as i32 + 1)));
134                let v_hat = v.clone() * (1.0 / (1.0 - self.beta2.powi(k as i32 + 1)));
135                let v_hat_sqrt_e_inv = v_hat
136                    .vec()
137                    .iter()
138                    .map(|vi| 1.0 / (vi.sqrt() + self.e))
139                    .collect::<Vec<_>>()
140                    .col_mat();
141
142                w = w - self.alpha * m_hat.hadamard_prod(&v_hat_sqrt_e_inv);
143
144                if !w.slice().l2_norm().is_finite() {
145                    return Status::NaN;
146                }
147            }
148
149            x.clone_from_slice(w.slice());
150        }
151
152        Status::Success
153    }
154}