non_convex_opt/algorithms/adam/adam_opt.rs

use nalgebra::{allocator::Allocator, DefaultAllocator, Dim, OMatrix, OVector, U1};

use crate::utils::config::AdamConf;
use crate::utils::opt_prob::{FloatNumber as FloatNum, OptProb, OptimizationAlgorithm, State};

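/// Adam optimizer tracking a single candidate solution.
///
/// `m` and `v` are the exponential moving averages of the gradient and of the
/// squared gradient; `v_hat` keeps the AMSGrad running maximum of the
/// bias-corrected second moment.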
pub struct Adam<T, N, D>
where
    T: FloatNum,
    N: Dim,
    D: Dim,
    OVector<T, D>: Send + Sync,
    OMatrix<T, N, D>: Send + Sync,
    DefaultAllocator: Allocator<D> + Allocator<N, D> + Allocator<N>,
{
    pub conf: AdamConf,
    pub st: State<T, N, D>,
    pub opt_prob: OptProb<T, D>,
    m: OVector<T, D>,
    v: OVector<T, D>,
    v_hat: OVector<T, D>,
}

impl<T, N, D> Adam<T, N, D>
where
    T: FloatNum,
    N: Dim,
    D: Dim,
    OVector<T, D>: Send + Sync,
    OMatrix<T, N, D>: Send + Sync,
    DefaultAllocator: Allocator<D> + Allocator<N, D> + Allocator<U1, D> + Allocator<N>,
{
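    /// Builds the optimizer from a 1×D initial population row; the single
    /// candidate seeds `best_x`/`best_f` and the moment vectors start at zero.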
    pub fn new(conf: AdamConf, init_pop: OMatrix<T, U1, D>, opt_prob: OptProb<T, D>) -> Self {
        let init_x: OVector<T, D> = init_pop.row(0).transpose().into_owned();
        let best_f = opt_prob.evaluate(&init_x);
        let n = init_x.len();

        Self {
            conf,
            st: State {
                best_x: init_x.clone(),
                best_f,
                pop: OMatrix::<T, N, D>::from_fn_generic(
                    N::from_usize(1),
                    D::from_usize(n),
                    |_, j| init_x[j],
                ),
                fitness: OVector::<T, N>::from_element_generic(N::from_usize(1), U1, best_f),
                constraints: OVector::<bool, N>::from_element_generic(
                    N::from_usize(1),
                    U1,
                    opt_prob.is_feasible(&init_x),
                ),
                iter: 1,
            },
            opt_prob,
            m: OVector::zeros_generic(D::from_usize(n), U1),
            v: OVector::zeros_generic(D::from_usize(n), U1),
            v_hat: OVector::zeros_generic(D::from_usize(n), U1),
        }
    }
}

impl<T, N, D> OptimizationAlgorithm<T, N, D> for Adam<T, N, D>
where
    T: FloatNum,
    N: Dim,
    D: Dim,
    OVector<T, D>: Send + Sync,
    OMatrix<T, N, D>: Send + Sync,
    DefaultAllocator: Allocator<D> + Allocator<N> + Allocator<N, D> + Allocator<U1, D>,
{
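    /// One Adam iteration on `best_x`: optional L2 weight decay folded into the
    /// gradient, optional gradient-norm clipping, moment updates with bias
    /// correction (and optionally AMSGrad), the scaled step, and a projection
    /// onto the variable bounds if constraints are violated.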
    fn step(&mut self) {
        let mut grad = self
            .opt_prob
            .objective
            .gradient(&self.st.best_x)
            .expect("ADAM requires gradient information");

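        // Optional L2-style weight decay: fold `weight_decay * x` into the gradient.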
        if self.conf.weight_decay > 0.0 {
            let weight_decay = T::from_f64(self.conf.weight_decay).unwrap();
            grad += &self.st.best_x * weight_decay;
        }

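        // Optionally clip the gradient by its global L2 norm.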
        if self.conf.gradient_clip > 0.0 {
            let clip_norm = T::from_f64(self.conf.gradient_clip).unwrap();
            let grad_norm = grad.dot(&grad).sqrt();
            if grad_norm > clip_norm {
                grad *= clip_norm / grad_norm;
            }
        }

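        // Exponential moving averages of the gradient (first moment) and of the
        // squared gradient (second moment).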
        self.m = self.m.clone() * T::from_f64(self.conf.beta1).unwrap()
            + grad.clone() * T::from_f64(1.0 - self.conf.beta1).unwrap();
        self.v = self.v.clone() * T::from_f64(self.conf.beta2).unwrap()
            + grad.component_mul(&grad) * T::from_f64(1.0 - self.conf.beta2).unwrap();

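        // Bias-correct both moments using the current iteration count.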
        let m_hat = self.m.clone()
            / (T::one() - T::from_f64(self.conf.beta1.powi(self.st.iter as i32)).unwrap());
        let v_hat = self.v.clone()
            / (T::one() - T::from_f64(self.conf.beta2.powi(self.st.iter as i32)).unwrap());

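        // AMSGrad: keep an element-wise running maximum of the bias-corrected
        // second moment so the denominator never shrinks between iterations.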
        if self.conf.amsgrad {
            for i in 0..self.v_hat.len() {
                self.v_hat[i] = self.v_hat[i].max(v_hat[i]);
            }
        }

        let step_size = T::from_f64(self.conf.learning_rate).unwrap();
        let epsilon = T::from_f64(self.conf.epsilon).unwrap();

        let v_denom = if self.conf.amsgrad {
            &self.v_hat
        } else {
            &v_hat
        };
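        // Ascent step: fitness is maximized here, so the scaled update is added
        // to `best_x` rather than subtracted as in descent-form Adam.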
        let update = m_hat.component_div(&v_denom.map(|x| x.sqrt() + epsilon)) * step_size;
        self.st.best_x += update;

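        // If the new point violates the constraints, clamp it to the variable
        // bounds (when the objective provides them).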
        if let Some(ref constraints) = self.opt_prob.constraints {
            if !constraints.g(&self.st.best_x) {
                if let (Some(lb), Some(ub)) = (
                    self.opt_prob.objective.x_lower_bound(&self.st.best_x),
                    self.opt_prob.objective.x_upper_bound(&self.st.best_x),
                ) {
                    for i in 0..self.st.best_x.len() {
                        self.st.best_x[i] = self.st.best_x[i].max(lb[i]).min(ub[i]);
                    }
                }
            }
        }

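        // Record the new fitness and refresh the single-row population state.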
        let fitness = self.opt_prob.evaluate(&self.st.best_x);
        if fitness > self.st.best_f {
            self.st.best_f = fitness;
        }

        self.st
            .pop
            .row_mut(0)
            .copy_from(&self.st.best_x.transpose());
        self.st.fitness[0] = fitness;
        self.st.constraints[0] = self.opt_prob.is_feasible(&self.st.best_x);

        self.st.iter += 1;
    }

    fn state(&self) -> &State<T, N, D>
    where
        DefaultAllocator: Allocator<N> + Allocator<D> + Allocator<N, D>,
    {
        &self.st
    }
}