// opensrdk_optimization/sgd_adam.rs
use crate::{vec::Vector as _, Status};
use opensrdk_linear_algebra::*;
use rand::{rngs::StdRng, seq::SliceRandom, SeedableRng};

/// Stochastic gradient descent optimizer using the Adam update rule.
pub struct SgdAdam {
    /// Convergence tolerance for the gradient-norm stopping test in `minimize`.
    epsilon: f64,
    /// Maximum number of epochs; `0` means no limit.
    max_iter: usize,
    /// Learning rate (Adam's step size, alpha).
    alpha: f64,
    /// Exponential decay rate for the first moment estimate.
    beta1: f64,
    /// Exponential decay rate for the second moment estimate.
    beta2: f64,
    /// Small constant added to the denominator for numerical stability.
    e: f64,
}
24impl Default for SgdAdam {
25 fn default() -> Self {
26 Self {
27 epsilon: 1e-6,
28 max_iter: 0,
29 alpha: 0.001,
30 beta1: 0.9,
31 beta2: 0.999,
32 e: 0.00000001,
33 }
34 }
35}
37impl SgdAdam {
38 pub fn new(epsilon: f64, max_iter: usize, alpha: f64, beta1: f64, beta2: f64, e: f64) -> Self {
39 Self {
40 epsilon,
41 max_iter,
42 alpha,
43 beta1,
44 beta2,
45 e,
46 }
47 }
48
49 pub fn with_epsilon(mut self, epsilon: f64) -> Self {
50 assert!(epsilon.is_sign_positive(), "epsilon must be positive");
51
52 self.epsilon = epsilon;
53
54 self
55 }
56
57 pub fn with_max_iter(mut self, max_iter: usize) -> Self {
58 self.max_iter = max_iter;
59
60 self
61 }
62
63 pub fn with_alpha(mut self, alpha: f64) -> Self {
64 assert!(alpha.is_sign_positive(), "delta must be positive");
65
66 self.alpha = alpha;
67
68 self
69 }
70
71 pub fn with_beta1(mut self, beta1: f64) -> Self {
72 assert!(beta1.is_sign_positive(), "beta1 must be positive");
73
74 self.beta2 = beta1;
75
76 self
77 }
78
79 pub fn with_beta2(mut self, beta2: f64) -> Self {
80 assert!(beta2.is_sign_positive(), "beta2 must be positive");
81
82 self.beta2 = beta2;
83
84 self
85 }
86
87 pub fn with_e(mut self, e: f64) -> Self {
88 assert!(e.is_sign_positive(), "e must be positive");
89
90 self.e = e;
91
92 self
93 }
94
95 pub fn minimize(
102 &self,
103 x: &mut [f64],
104 grad: &dyn Fn(&[usize], &[f64]) -> Vec<f64>,
105 batch: usize,
106 total: usize,
107 ) -> Status {
108 let mut batch_index = (0..total).into_iter().collect::<Vec<_>>();
109 let mut w = x.to_vec().col_mat();
110 let mut m = Matrix::new(w.rows(), 1);
111 let mut v = Matrix::new(w.rows(), 1);
112 let mut rng: StdRng = SeedableRng::seed_from_u64(1);
113
114 for k in 0.. {
115 if self.max_iter != 0 && self.max_iter <= k {
116 return Status::MaxIter;
117 }
118
119 let gfx = grad(&(0..total).into_iter().collect::<Vec<_>>(), w.slice());
120 if gfx.l2_norm() < self.epsilon + x.l2_norm() {
121 return Status::Epsilon;
122 }
123
124 batch_index.shuffle(&mut rng);
125
126 for minibatch in batch_index.chunks(batch) {
127 let minibatch_grad = grad(&minibatch, w.slice()).col_mat();
128
129 m = self.beta1 * m + (1.0 - self.beta1) * minibatch_grad.clone();
130 v = self.beta2 * v
131 + (1.0 - self.beta2) * minibatch_grad.clone().hadamard_prod(&minibatch_grad);
132
133 let m_hat = m.clone() * (1.0 / (1.0 - self.beta1.powi(k as i32 + 1)));
134 let v_hat = v.clone() * (1.0 / (1.0 - self.beta2.powi(k as i32 + 1)));
135 let v_hat_sqrt_e_inv = v_hat
136 .vec()
137 .iter()
138 .map(|vi| 1.0 / (vi.sqrt() + self.e))
139 .collect::<Vec<_>>()
140 .col_mat();
141
142 w = w - self.alpha * m_hat.hadamard_prod(&v_hat_sqrt_e_inv);
143
144 if !w.slice().l2_norm().is_finite() {
145 return Status::NaN;
146 }
147 }
148
149 x.clone_from_slice(w.slice());
150 }
151
152 Status::Success
153 }
154}