use std::array;
use std::time::Instant;
use crate::dual::Dual;
use crate::simd_arr::dense_simd::DenseSimd;
use crate::simd_arr::hybrid_simd::{CriticalityCue, HybridSimd};
use crate::simd_arr::SimdArr;
use rand::{Rng, SeedableRng};
use rand_chacha::ChaCha8Rng;
use super::Minimizer;
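/// Adam optimizer (adaptive moment estimation) over `P` parameters, with
/// gradients computed exactly through forward-mode dual numbers.
///
/// * `S` selects the derivative storage backend (`DenseSimd` or `HybridSimd`).
/// * `FG` is the cost expressed over dual numbers (used to obtain gradients);
///   `F` is the same cost over plain `f32` (used for cheap re-evaluation).
/// * `ParamTranslate` maps `(current params, proposed step)` to the new
///   parameter values.
///
/// A minimal usage sketch, assuming a two-parameter quadratic cost; the
/// closure shapes follow the trait bounds below, everything else is
/// illustrative:
///
/// ```ignore
/// // Plain cost, used to track progress cheaply.
/// let cost = |p: &[f32; 2], _: &()| p[0] * p[0] + p[1] * p[1];
/// // The same cost over dual numbers, from which the gradient is read off.
/// let cost_gradient = |p: &[Dual<2, DenseSimd<2>>; 2], _: &()| {
///     p[0].clone() * p[0].clone() + p[1].clone() * p[1].clone()
/// };
/// // Identity translator: apply the raw Adam step unmodified.
/// let translate = |p: &[f32; 2], d: &[f32; 2]| [p[0] + d[0], p[1] + d[1]];
///
/// let mut minimizer =
///     AdamMinimizer::new_dense(cost, cost_gradient, translate, (), 1e-6, 100);
/// while !minimizer.found_local_minima() {
///     minimizer.train_step::<true>(0.01);
/// }
/// ```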
#[derive(Clone)]
pub struct AdamMinimizer<
const P: usize,
ExtraData: Sync + Clone,
S: SimdArr<P>,
FG: Fn(&[Dual<P, S>; P], &ExtraData) -> Dual<P, S> + Sync + Clone,
F: Fn(&[f32; P], &ExtraData) -> f32 + Sync + Clone,
ParamTranslate: Fn(&[f32; P], &[f32; P]) -> [f32; P] + Clone,
> {
cost_gradient: FG,
cost: F,
params: [Dual<P, S>; P],
    m: [f32; P],
    v: [f32; P],
    t: usize,
    param_translator: ParamTranslate,
extra_data: ExtraData,
last_cost: Option<f32>,
cost_stagnation_value: Option<f32>,
cost_stagnation_time: usize,
cost_stagnation_threshold: f32,
max_cost_stagnation_time: usize,
}
impl<
const P: usize,
ExtraData: Sync + Clone,
FG: Fn(&[Dual<P, DenseSimd<P>>; P], &ExtraData) -> Dual<P, DenseSimd<P>> + Sync + Clone,
F: Fn(&[f32; P], &ExtraData) -> f32 + Sync + Clone,
ParamTranslate: Fn(&[f32; P], &[f32; P]) -> [f32; P] + Clone,
> AdamMinimizer<P, ExtraData, DenseSimd<P>, FG, F, ParamTranslate>
{
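    /// Builds a minimizer whose derivative storage is fully dense
    /// (`DenseSimd`), suited to costs where most parameters contribute to
    /// every evaluation.
    ///
    /// Parameters are initialised uniformly in `[-0.5, 0.5)` from a fixed
    /// ChaCha8 seed, so construction is reproducible.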
pub fn new_dense(
trainable: F,
trainable_gradient: FG,
param_translator: ParamTranslate,
extra_data: ExtraData,
cost_stagnation_threshold: f32,
max_cost_stagnation_time: usize,
) -> Self {
let mut rng = ChaCha8Rng::seed_from_u64(2);
Self {
cost_gradient: trainable_gradient,
cost: trainable,
params: array::from_fn(|i| Dual::new_param(rng.gen::<f32>() - 0.5, i)),
m: [0.0; P],
v: [0.0; P],
t: 0,
param_translator,
extra_data,
last_cost: None,
cost_stagnation_value: None,
cost_stagnation_threshold,
cost_stagnation_time: 0,
max_cost_stagnation_time,
}
}
}
impl<
const P: usize,
const CRITICALITY: usize,
ExtraData: Sync + Clone,
FG: Fn(
&[Dual<P, HybridSimd<P, CRITICALITY>>; P],
&ExtraData,
) -> Dual<P, HybridSimd<P, CRITICALITY>>
+ Sync
+ Clone,
F: Fn(&[f32; P], &ExtraData) -> f32 + Sync + Clone,
ParamTranslate: Fn(&[f32; P], &[f32; P]) -> [f32; P] + Clone,
> AdamMinimizer<P, ExtraData, HybridSimd<P, CRITICALITY>, FG, F, ParamTranslate>
{
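    /// Builds a minimizer backed by `HybridSimd`, whose `CRITICALITY`
    /// constant tunes when the derivative storage switches representation;
    /// the zero-sized `CriticalityCue` argument only carries that constant at
    /// the type level.
    ///
    /// Parameter initialisation matches `new_dense`.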
pub fn new_hybrid(
_: CriticalityCue<CRITICALITY>,
trainable: F,
trainable_gradient: FG,
param_translator: ParamTranslate,
extra_data: ExtraData,
cost_stagnation_threshold: f32,
max_cost_stagnation_time: usize,
) -> Self {
let mut rng = ChaCha8Rng::seed_from_u64(2);
Self {
cost_gradient: trainable_gradient,
cost: trainable,
params: array::from_fn(|i| Dual::new_param(rng.gen::<f32>() - 0.5, i)),
m: [0.0; P],
v: [0.0; P],
t: 0,
param_translator,
extra_data,
last_cost: None,
cost_stagnation_value: None,
cost_stagnation_threshold,
cost_stagnation_time: 0,
max_cost_stagnation_time,
}
}
}
impl<
const P: usize,
ExtraData: Sync + Clone,
S: SimdArr<P>,
FG: Fn(&[Dual<P, S>; P], &ExtraData) -> Dual<P, S> + Sync + Clone,
F: Fn(&[f32; P], &ExtraData) -> f32 + Sync + Clone,
ParamTranslate: Fn(&[f32; P], &[f32; P]) -> [f32; P] + Clone,
> Minimizer<P> for AdamMinimizer<P, ExtraData, S, FG, F, ParamTranslate>
{
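    /// Cost recorded at the end of the most recent `train_step`, if any.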
fn get_last_cost(&self) -> Option<f32> {
self.last_cost
}
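    /// Performs one Adam step: evaluates the dual-number cost to obtain the
    /// gradient g, updates the moment estimates
    /// `m = β1·m + (1 - β1)·g` and `v = β2·v + (1 - β2)·g²`,
    /// bias-corrects them, moves the parameters through `param_translator`,
    /// and updates the cost-stagnation bookkeeping.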
fn train_step<const VERBOSE: bool>(&mut self, learning_rate: f32) {
let t0 = Instant::now();
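        // A single forward pass over dual numbers yields the cost together
        // with all P partial derivatives.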
let cost = (self.cost_gradient)(&self.params, &self.extra_data);
let raw_gradient = cost.get_gradient();
        // L2 norm of the gradient, floored so the reported value never hits
        // exactly zero.
        let gradient_size = f32::max(
            raw_gradient.iter().fold(0., |acc, elm| acc + elm * elm),
            1e-30,
        )
        .sqrt();
self.t += 1;
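        // Exponential moving averages of the gradient (first moment) and of
        // its element-wise square (second moment), with β1 = 0.9, β2 = 0.999.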
for i in 0..P {
            self.m[i] = 0.9 * self.m[i] + 0.1 * raw_gradient[i];
            self.v[i] = 0.999 * self.v[i] + 0.001 * raw_gradient[i].powi(2);
        }
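        // Bias correction compensates for the zero initialisation of m and v:
        // m̂ = m / (1 - β1ᵗ), v̂ = v / (1 - β2ᵗ).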
let m_hat: [f32; P] = array::from_fn(|i| self.m[i] / (1.0 - 0.9_f32.powi(self.t as i32)));
let v_hat: [f32; P] = array::from_fn(|i| self.v[i] / (1.0 - 0.999_f32.powi(self.t as i32)));
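        // Per-parameter step: -lr · m̂ / (√v̂ + ε), with ε = 1e-8.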
let gradient = array::from_fn(|i| -learning_rate * m_hat[i] / (v_hat[i].sqrt() + 1e-8));
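        // The translator maps (current params, step) to new params, letting
        // callers constrain or re-project the update.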
let og_parameters = array::from_fn(|i| self.params[i].get_real());
let new_params = (self.param_translator)(&og_parameters, &gradient);
for (i, param) in new_params.iter().enumerate() {
self.params[i].set_real(*param);
}
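        // A plain f32 evaluation suffices here; no gradient is needed for the
        // stagnation bookkeeping below.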
let new_cost = (self.cost)(&new_params, &self.extra_data);
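        // Stagnation tracking: while successive costs stay within the
        // threshold, count how long the cost hovers around a reference value;
        // a large enough change resets both the counter and the reference.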
if let Some(last_cost) = self.last_cost {
if (last_cost - new_cost).abs() < self.cost_stagnation_threshold {
if let Some(stagnation_value) = self.cost_stagnation_value {
if (new_cost - stagnation_value).abs() < self.cost_stagnation_threshold {
self.cost_stagnation_time += 1;
} else {
self.cost_stagnation_time = 0;
self.cost_stagnation_value = Some(new_cost);
}
} else {
self.cost_stagnation_time = 0;
self.cost_stagnation_value = Some(new_cost);
}
} else {
self.cost_stagnation_value = None;
self.cost_stagnation_time = 0;
}
}
self.last_cost = Some(new_cost);
if VERBOSE {
println!(
"Gradient length: {gradient_size:?} - New cost: {} - Time: {}",
self.last_cost.unwrap(),
t0.elapsed().as_secs_f32()
);
}
}
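    /// Extracts the real (non-derivative) part of every parameter.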
    fn get_model_params(&self) -> [f32; P] {
        array::from_fn(|i| self.params[i].get_real())
    }
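    /// Overwrites the parameters, re-tagging each value with its derivative
    /// index; the Adam moments `m`, `v` and the step counter `t` are left
    /// untouched.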
    fn set_model_params(&mut self, parameters: [f32; P]) {
        self.params = array::from_fn(|i| Dual::new_param(parameters[i], i));
    }
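    /// A local minimum is assumed once the cost has stagnated for more than
    /// `max_cost_stagnation_time` consecutive steps.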
    fn found_local_minima(&self) -> bool {
        self.cost_stagnation_time > self.max_cost_stagnation_time
    }
}