non_convex_opt/algorithms/adam/adam_opt.rs

use nalgebra::{allocator::Allocator, DefaultAllocator, Dim, OMatrix, OVector, U1};

use crate::utils::config::AdamConf;
use crate::utils::opt_prob::{FloatNumber as FloatNum, OptProb, OptimizationAlgorithm, State};

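/// Adam optimizer with optional weight decay, gradient clipping, and the
/// AMSGrad variant. `st.best_x` doubles as the current iterate, while `m`,
/// `v`, and `v_hat` hold the exponential moving averages of the gradient,
/// the squared gradient, and (for AMSGrad) the running maximum of the
/// bias-corrected second moment.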
pub struct Adam<T, N, D>
where
    T: FloatNum,
    N: Dim,
    D: Dim,
    OVector<T, D>: Send + Sync,
    OMatrix<T, N, D>: Send + Sync,
    DefaultAllocator: Allocator<D> + Allocator<N, D> + Allocator<N>,
{
    pub conf: AdamConf,
    pub st: State<T, N, D>,
    pub opt_prob: OptProb<T, D>,
    m: OVector<T, D>,     // First moment estimate
    v: OVector<T, D>,     // Second moment estimate
    v_hat: OVector<T, D>, // Max of second moment estimate (for AMSGrad)
}

impl<T, N, D> Adam<T, N, D>
where
    T: FloatNum,
    N: Dim,
    D: Dim,
    OVector<T, D>: Send + Sync,
    OMatrix<T, N, D>: Send + Sync,
    DefaultAllocator: Allocator<D> + Allocator<N, D> + Allocator<U1, D> + Allocator<N>,
{
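    /// Builds the optimizer from the first row of `init_pop`: the starting
    /// point is evaluated once, the single-row population state is seeded
    /// with it, and the moment estimates are zero-initialized.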
    pub fn new(conf: AdamConf, init_pop: OMatrix<T, U1, D>, opt_prob: OptProb<T, D>) -> Self {
        let init_x: OVector<T, D> = init_pop.row(0).transpose().into_owned();
        let best_f = opt_prob.evaluate(&init_x);
        let n = init_x.len();

        Self {
            conf,
            st: State {
                best_x: init_x.clone(),
                best_f,
                pop: OMatrix::<T, N, D>::from_fn_generic(
                    N::from_usize(1),
                    D::from_usize(n),
                    |_, j| init_x[j],
                ),
                fitness: OVector::<T, N>::from_element_generic(N::from_usize(1), U1, best_f),
                constraints: OVector::<bool, N>::from_element_generic(
                    N::from_usize(1),
                    U1,
                    opt_prob.is_feasible(&init_x),
                ),
                iter: 1,
            },
            opt_prob,
            m: OVector::zeros_generic(D::from_usize(n), U1),
            v: OVector::zeros_generic(D::from_usize(n), U1),
            v_hat: OVector::zeros_generic(D::from_usize(n), U1),
        }
    }
}

impl<T, N, D> OptimizationAlgorithm<T, N, D> for Adam<T, N, D>
where
    T: FloatNum,
    N: Dim,
    D: Dim,
    OVector<T, D>: Send + Sync,
    OMatrix<T, N, D>: Send + Sync,
    DefaultAllocator: Allocator<D> + Allocator<N> + Allocator<N, D> + Allocator<U1, D>,
{
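    // One Adam step on `st.best_x`, which doubles as the current iterate:
    //   m_t     = beta1 * m_{t-1} + (1 - beta1) * g_t
    //   v_t     = beta2 * v_{t-1} + (1 - beta2) * g_t^2
    //   m_hat_t = m_t / (1 - beta1^t),  v_hat_t = v_t / (1 - beta2^t)
    //   x_{t+1} = x_t + lr * m_hat_t / (sqrt(v_hat_t) + epsilon)
    // The plus sign follows this framework's convention of maximizing fitness
    // (gradient ascent). With AMSGrad, v_hat_t is replaced by its running
    // element-wise maximum.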
    fn step(&mut self) {
        let mut grad = self
            .opt_prob
            .objective
            .gradient(&self.st.best_x)
            .expect("Adam requires gradient information");

        // Weight decay: add the decay term to the gradient
        if self.conf.weight_decay > 0.0 {
            let weight_decay = T::from_f64(self.conf.weight_decay).unwrap();
            grad += &self.st.best_x * weight_decay;
        }

        // Gradient clipping by global L2 norm
        if self.conf.gradient_clip > 0.0 {
            let clip_norm = T::from_f64(self.conf.gradient_clip).unwrap();
            let grad_norm = grad.dot(&grad).sqrt();
            if grad_norm > clip_norm {
                grad *= clip_norm / grad_norm;
            }
        }

        // Biased first and second moment estimates
        self.m = self.m.clone() * T::from_f64(self.conf.beta1).unwrap()
            + grad.clone() * T::from_f64(1.0 - self.conf.beta1).unwrap();
        self.v = self.v.clone() * T::from_f64(self.conf.beta2).unwrap()
            + grad.component_mul(&grad) * T::from_f64(1.0 - self.conf.beta2).unwrap();

        // Bias-corrected moment estimates
        let m_hat = self.m.clone()
            / (T::one() - T::from_f64(self.conf.beta1.powi(self.st.iter as i32)).unwrap());
        let v_hat = self.v.clone()
            / (T::one() - T::from_f64(self.conf.beta2.powi(self.st.iter as i32)).unwrap());

        // AMSGrad: keep the element-wise maximum of the bias-corrected second moment
        if self.conf.amsgrad {
            for i in 0..self.v_hat.len() {
                self.v_hat[i] = self.v_hat[i].max(v_hat[i]);
            }
        }

        let step_size = T::from_f64(self.conf.learning_rate).unwrap();
        let epsilon = T::from_f64(self.conf.epsilon).unwrap();

        let v_denom = if self.conf.amsgrad {
            &self.v_hat
        } else {
            &v_hat
        };
        let update = m_hat.component_div(&v_denom.map(|x| x.sqrt() + epsilon)) * step_size;
        self.st.best_x += update;

        // Project back onto the bound constraints if the step left the feasible set
        if let Some(ref constraints) = self.opt_prob.constraints {
            if !constraints.g(&self.st.best_x) {
                if let (Some(lb), Some(ub)) = (
                    self.opt_prob.objective.x_lower_bound(&self.st.best_x),
                    self.opt_prob.objective.x_upper_bound(&self.st.best_x),
                ) {
                    for i in 0..self.st.best_x.len() {
                        self.st.best_x[i] = self.st.best_x[i].max(lb[i]).min(ub[i]);
                    }
                }
            }
        }

        // Record the new iterate; best_f only moves when the fitness improves
        let fitness = self.opt_prob.evaluate(&self.st.best_x);
        if fitness > self.st.best_f {
            self.st.best_f = fitness;
        }

        self.st
            .pop
            .row_mut(0)
            .copy_from(&self.st.best_x.transpose());
        self.st.fitness[0] = fitness;
        self.st.constraints[0] = self.opt_prob.is_feasible(&self.st.best_x);

        self.st.iter += 1;
    }

    fn state(&self) -> &State<T, N, D>
    where
        DefaultAllocator: Allocator<N> + Allocator<D> + Allocator<N, D>,
    {
        &self.st
    }
}