use crate::{
algorithms::gradient::GradientStatus,
core::{Callbacks, MinimizationSummary},
error::{GaneshError, GaneshResult},
traits::{
Algorithm, CostFunction, Gradient, Status, SupportsParameterNames, SupportsTransform,
Terminator, Transform, TransformedProblem,
},
DMatrix, DVector, Float,
};
use std::ops::ControlFlow;
/// A terminator for [`Adam`] that stops the minimization once an exponential
/// moving average (EMA) of the loss stops changing.
///
/// Each step the EMA is updated as `ema <- beta_c * ema + (1 - beta_c) * loss`;
/// if the EMA moves by less than `eps_loss` for `patience` consecutive steps,
/// the run is terminated (see the [`Terminator`] impl below).
#[derive(Copy, Clone)]
pub struct AdamEMATerminator {
    /// Smoothing factor for the loss EMA (default `0.9`).
    pub beta_c: Float,
    /// Absolute change in the EMA below which a step counts as "no
    /// improvement" (default `sqrt(Float::EPSILON)`).
    pub eps_loss: Float,
    /// Number of consecutive no-improvement steps required before
    /// terminating (default `1`).
    pub patience: usize,
}
impl Default for AdamEMATerminator {
fn default() -> Self {
Self {
beta_c: 0.9,
eps_loss: Float::EPSILON.sqrt(),
patience: 1,
}
}
}
impl<P, U, E> Terminator<Adam, P, GradientStatus, U, E, AdamConfig> for AdamEMATerminator
where
    P: Gradient<U, E>,
{
    /// Updates the loss EMA stored on the [`Adam`] state and breaks the
    /// minimization loop once the EMA has changed by less than `eps_loss`
    /// for `patience` consecutive steps.
    ///
    /// NOTE(review): `algorithm.ema_loss` starts at `0.0` (it is only zeroed
    /// in `Adam::reset`, never seeded with the initial loss), so the first
    /// few EMA values are biased toward zero — confirm this warm-up bias is
    /// intended.
    fn check_for_termination(
        &mut self,
        _current_step: usize,
        algorithm: &mut Adam,
        _problem: &P,
        status: &mut GradientStatus,
        _args: &U,
        _config: &AdamConfig,
    ) -> ControlFlow<()> {
        // EMA update: ema <- beta_c * ema + (1 - beta_c) * f. The first
        // product uses a fused multiply-add, so rounding differs slightly
        // from the naive `a * b + c` form.
        let prev_ema_loss = algorithm.ema_loss;
        algorithm.ema_loss = self
            .beta_c
            .mul_add(prev_ema_loss, (1.0 - self.beta_c) * algorithm.f);
        // Count consecutive "flat" steps; any sufficiently large move in the
        // EMA resets the counter.
        if (algorithm.ema_loss - prev_ema_loss).abs() < self.eps_loss {
            algorithm.ema_counter += 1;
        } else {
            algorithm.ema_counter = 0;
        }
        if algorithm.ema_counter >= self.patience {
            // Mark the run as successful with a termination message.
            // NOTE(review): the zero-argument `set_message()` chained into
            // `succeed_with_message` looks unusual — confirm this matches the
            // `Status` trait API and is not a merge artifact.
            status.set_message().succeed_with_message(&format!(
                "EMA LOSS HAS NOT IMPROVED IN {} STEPS",
                algorithm.ema_counter
            ));
            return ControlFlow::Break(());
        }
        ControlFlow::Continue(())
    }
}
/// Hyperparameter configuration for the [`Adam`] optimizer.
///
/// Built via [`AdamConfig::new`]/[`Default`] and the `with_*` builder methods.
#[derive(Clone)]
pub struct AdamConfig {
    /// Optional human-readable parameter names, reported in the final summary.
    parameter_names: Option<Vec<String>>,
    /// Optional coordinate transform (e.g. bounds) wrapped around the problem.
    /// NOTE(review): deriving `Clone` requires `Box<dyn Transform>` to be
    /// cloneable — presumably the `Transform` trait provides this; confirm.
    transform: Option<Box<dyn Transform>>,
    /// Base learning rate (default `0.001`).
    alpha: Float,
    /// Exponential decay rate for the first-moment estimate (default `0.9`).
    beta_1: Float,
    /// Exponential decay rate for the second-moment estimate (default `0.999`).
    beta_2: Float,
    /// Small constant added to the denominator to avoid division by zero
    /// (default `1e-8`).
    epsilon: Float,
}
impl AdamConfig {
pub fn new() -> Self {
Self::default()
}
pub fn with_alpha(mut self, value: Float) -> GaneshResult<Self> {
if value <= 0.0 {
return Err(GaneshError::ConfigError(
"Initial learning rate must be positive and greater than 0".to_string(),
));
}
self.alpha = value;
Ok(self)
}
pub fn with_beta_1(mut self, value: Float) -> GaneshResult<Self> {
if !(0.0..1.0).contains(&value) {
return Err(GaneshError::ConfigError(
"beta_1 must be in the range [0, 1)".to_string(),
));
}
self.beta_1 = value;
Ok(self)
}
pub fn with_beta_2(mut self, value: Float) -> GaneshResult<Self> {
if !(0.0..1.0).contains(&value) {
return Err(GaneshError::ConfigError(
"beta_2 must be in the range [0, 1)".to_string(),
));
}
self.beta_2 = value;
Ok(self)
}
pub fn with_epsilon(mut self, value: Float) -> GaneshResult<Self> {
if value <= 0.0 {
return Err(GaneshError::ConfigError(
"Divide-by-zero tolerance must be positive and greater than 0".to_string(),
));
}
self.epsilon = value;
Ok(self)
}
}
impl Default for AdamConfig {
fn default() -> Self {
Self {
parameter_names: None,
transform: None,
alpha: 0.001,
beta_1: 0.9,
beta_2: 0.999,
epsilon: 1e-8, }
}
}
impl SupportsTransform for AdamConfig {
    /// Exposes the optional coordinate transform for the builder helpers
    /// provided by [`SupportsTransform`] (e.g. `with_transform`, used in the
    /// tests below).
    fn get_transform_mut(&mut self) -> &mut Option<Box<dyn Transform>> {
        &mut self.transform
    }
}
impl SupportsParameterNames for AdamConfig {
    /// Exposes the optional parameter-name list for the builder helpers
    /// provided by [`SupportsParameterNames`].
    fn get_parameter_names_mut(&mut self) -> &mut Option<Vec<String>> {
        &mut self.parameter_names
    }
}
/// The Adam (adaptive moment estimation) gradient-descent algorithm.
///
/// All positional state is kept in the *internal* coordinate space of the
/// configured [`Transform`] (if any); conversion to external coordinates
/// happens only when reporting to the status.
#[derive(Clone, Default)]
pub struct Adam {
    /// Current position (internal coordinates).
    x: DVector<Float>,
    /// Loss evaluated at `x`.
    f: Float,
    /// Gradient evaluated at `x` (internal coordinates).
    g: DVector<Float>,
    /// First-moment (mean) estimate of the gradient.
    m: DVector<Float>,
    /// Second-moment (uncentered variance) estimate of the gradient.
    v: DVector<Float>,
    /// Loss EMA maintained by [`AdamEMATerminator`].
    ema_loss: Float,
    /// Consecutive no-improvement steps counted by [`AdamEMATerminator`].
    ema_counter: usize,
}
impl<P, U, E> Algorithm<P, GradientStatus, U, E> for Adam
where
    P: Gradient<U, E>,
{
    type Summary = MinimizationSummary;
    type Config = AdamConfig;
    /// The starting position, given in external (untransformed) coordinates.
    type Init = DVector<Float>;
    /// Prepares the optimizer: maps `init` into internal coordinates,
    /// evaluates the loss there, and zeroes the gradient and both moment
    /// estimates.
    fn initialize(
        &mut self,
        problem: &P,
        status: &mut GradientStatus,
        args: &U,
        init: &Self::Init,
        config: &Self::Config,
    ) -> Result<(), E> {
        let t_problem = TransformedProblem::new(problem, &config.transform);
        // Work in the transform's internal space; `init` is external.
        self.x = t_problem.to_owned_internal(init);
        self.g = DVector::zeros(self.x.len());
        self.f = t_problem.evaluate(&self.x, args)?;
        // The status tracks the external-space position alongside its loss.
        status.initialize((init.clone(), self.f));
        status.inc_n_f_evals();
        // Both moment estimates start at zero, per the Adam algorithm.
        self.m = DVector::zeros(self.x.len());
        self.v = DVector::zeros(self.x.len());
        Ok(())
    }
    /// Performs one Adam update (Kingma & Ba, Algorithm 1): refresh the
    /// moment estimates from the current gradient, take a bias-corrected
    /// step, then re-evaluate the loss at the new position.
    fn step(
        &mut self,
        i_step: usize,
        problem: &P,
        status: &mut GradientStatus,
        args: &U,
        config: &Self::Config,
    ) -> Result<(), E> {
        let t_problem = TransformedProblem::new(problem, &config.transform);
        self.g = t_problem.gradient(&self.x, args)?;
        status.inc_n_g_evals();
        // m <- beta_1 * m + (1 - beta_1) * g   (first moment)
        self.m = self.m.scale(config.beta_1) + self.g.scale(1.0 - config.beta_1);
        // v <- beta_2 * v + (1 - beta_2) * g^2 (second moment, elementwise)
        self.v =
            self.v.scale(config.beta_2) + self.g.map(|gi| gi.powi(2)).scale(1.0 - config.beta_2);
        // Bias correction folded into the step size:
        //   alpha_t = alpha * sqrt(1 - beta_2^t) / (1 - beta_1^t),
        // with t = i_step + 1 so the first step uses t = 1.
        let alpha_t = config.alpha * (1.0 - config.beta_2.powi(i_step as i32 + 1)).sqrt()
            / (1.0 - config.beta_1.powi(i_step as i32 + 1));
        // x <- x - alpha_t * m / (sqrt(v) + epsilon), elementwise.
        self.x -= self
            .m
            .scale(alpha_t)
            .component_div(&self.v.map(|vi| vi.sqrt() + config.epsilon));
        self.f = t_problem.evaluate(&self.x, args)?;
        status.inc_n_f_evals();
        // Report the new position in external coordinates.
        status.set_position((t_problem.to_owned_external(&self.x), self.f));
        Ok(())
    }
    /// Assembles the final [`MinimizationSummary`] from the status, falling
    /// back to zero uncertainties and an identity covariance when none were
    /// computed.
    fn summarize(
        &self,
        _current_step: usize,
        _problem: &P,
        status: &GradientStatus,
        _args: &U,
        init: &Self::Init,
        config: &Self::Config,
    ) -> Result<Self::Summary, E> {
        Ok(MinimizationSummary {
            x0: init.clone(),
            x: status.x.clone(),
            fx: status.fx,
            bounds: None,
            n_f_evals: status.n_f_evals,
            n_g_evals: status.n_g_evals,
            n_h_evals: status.n_h_evals,
            message: status.message.clone(),
            parameter_names: config.parameter_names.clone(),
            std: status
                .err
                .clone()
                .unwrap_or_else(|| DVector::from_element(status.x.len(), 0.0)),
            covariance: status
                .cov
                .clone()
                .unwrap_or_else(|| DMatrix::identity(status.x.len(), status.x.len())),
        })
    }
    /// Clears only the EMA-terminator bookkeeping; position, gradient, and
    /// moment estimates are rebuilt by [`Algorithm::initialize`].
    fn reset(&mut self) {
        self.ema_loss = 0.0;
        self.ema_counter = 0;
    }
    /// By default, Adam runs with an [`AdamEMATerminator`] attached.
    fn default_callbacks() -> Callbacks<Self, P, GradientStatus, U, E, Self::Config>
    where
        Self: Sized,
    {
        Callbacks::empty().with_terminator(AdamEMATerminator::default())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        core::{Bounds, MaxSteps},
        test_functions::Rosenbrock,
    };
    use approx::assert_relative_eq;

    /// Starting points covering all four quadrants, the global minimum
    /// (1, 1), and the origin.
    const STARTS: [[Float; 2]; 6] = [
        [-2.0, 2.0],
        [2.0, 2.0],
        [2.0, -2.0],
        [-2.0, -2.0],
        [1.0, 1.0],
        [0.0, 0.0],
    ];

    /// Runs Adam on the 2-D Rosenbrock function from every start in
    /// `STARTS` using the given configuration and asserts that each run
    /// terminates successfully at f ~ 0. Shared by the bounded and
    /// unbounded tests, which previously duplicated this loop verbatim.
    fn assert_converges(config: AdamConfig) {
        let mut solver = Adam::default();
        let problem = Rosenbrock { n: 2 };
        for start in STARTS {
            let result = solver
                .process(
                    &problem,
                    &(),
                    DVector::from_row_slice(&start),
                    config.clone(),
                    // Cap the run so a regression cannot hang the test suite.
                    Adam::default_callbacks().with_terminator(MaxSteps(1_000_000)),
                )
                .unwrap();
            assert!(result.message.success());
            assert_relative_eq!(result.fx, 0.0, epsilon = Float::EPSILON.cbrt());
        }
    }

    #[test]
    fn test_adam() {
        assert_converges(AdamConfig::default());
    }

    #[test]
    fn test_bounded_adam() {
        // Same problem, but constrained to a box well containing the minimum.
        assert_converges(
            AdamConfig::default().with_transform(&Bounds::from([(-4.0, 4.0), (-4.0, 4.0)])),
        );
    }
}