use log::Level::*;
use log::{debug, error, info, log_enabled, trace, warn};
use cpu_time::ProcessTime;
use std::time::SystemTime;
use std::iter::FromIterator;
use rand::distributions::Distribution;
use rand::SeedableRng;
use ndarray::{Array, Dimension};
use rand_xoshiro::Xoshiro256PlusPlus;
use crate::monitor::*;
use crate::types::*;
/// Batch sizing data computed for one optimizer iteration.
/// Values are produced by `get_batch_size_at_jstep` and consumed by the
/// `minimize` loop. NOTE(review): naming follows the SCSG literature
/// (m_j, b_j, B_j) — confirm against the reference paper.
pub struct BatchSizeInfo {
    // iteration index this info was computed for (kept for debugging only)
    _step: usize,
    // number of terms drawn for the large-batch gradient estimate
    large_batch: usize,
    // number of terms drawn for each small mini batch
    mini_batch: usize,
    // m_j: (fractional) parameter sizing the number of mini batches this iteration
    nb_mini_batch_parameter: f64,
    // gradient step size used at this iteration
    step_size: f64,
}
/// Stochastic Controlled Gradient Descent minimizer.
/// Batch sizes grow geometrically across iterations while the step size decays;
/// the growth rate is derived from the iteration budget in
/// `estimate_batch_growing_factor`.
pub struct StochasticControlledGradientDescent {
    // random generator used for batch sampling; deterministically seeded by
    // default, reseedable via `seed()`
    rng: Xoshiro256PlusPlus,
    // eta_0: initial step size
    eta_zero: f64,
    // m_0: base fraction of the number of terms, sizing the number of mini batches
    m_zero: f64,
    // b_0: initial mini batch size
    mini_batch_size_init: usize,
    // initial fraction of the terms used for the large batch
    large_batch_fraction_init: f64,
    // hard cap on the large batch size, as a fraction of the number of terms
    large_batch_max_fraction: f64,
}
impl StochasticControlledGradientDescent {
    /// Builds a new optimizer.
    ///
    /// * `eta_zero`: initial step size.
    /// * `m_zero`: base fraction (of the number of terms) sizing the number of
    ///   mini batches per iteration; must not exceed `large_batch_fraction_init`.
    /// * `mini_batch_size_init`: initial mini batch size.
    /// * `large_batch_fraction_init`: initial fraction of terms used for the
    ///   large batch; must not exceed 1.
    ///
    /// On invalid arguments the process exits with status 1 (original contract).
    pub fn new(
        eta_zero: f64,
        m_zero: f64,
        mini_batch_size_init: usize,
        large_batch_fraction_init: f64,
    ) -> StochasticControlledGradientDescent {
        if large_batch_fraction_init > 1. {
            warn!("large_batch_size_init > 1. , fraction factor for large_batch size initialization must be < 1. , exiting");
            std::process::exit(1);
        }
        if m_zero > large_batch_fraction_init {
            warn!("m_zero > large_batch_size_init fraction , base fraction for nb mini batch should be less than for large_batch_size");
            std::process::exit(1);
        }
        info!(" eta_zero {:2.4E} m_zero {:2.4} \n mini batch size {:?} , large_batch_fraction_init {:2.4E}",
            eta_zero, m_zero, mini_batch_size_init, large_batch_fraction_init);
        StochasticControlledGradientDescent {
            // fixed default seed so runs are reproducible unless seed() is called
            rng: Xoshiro256PlusPlus::seed_from_u64(4664397),
            eta_zero,
            m_zero,
            mini_batch_size_init,
            large_batch_fraction_init,
            large_batch_max_fraction: 0.1,
        }
    }

    /// Reseeds the internal random generator (for reproducibility control).
    pub fn seed(&mut self, seed: [u8; 32]) {
        self.rng = Xoshiro256PlusPlus::from_seed(seed);
    }

    /// Sets the maximal fraction of terms a large batch may use.
    /// Panics unless `fraction` lies strictly between 0.01 and 1.
    pub fn set_large_batch_max_fraction(&mut self, fraction: f64) {
        assert!(0.01 < fraction && fraction < 1.);
        self.large_batch_max_fraction = fraction;
    }

    /// Estimates the geometric factor by which batch sizes grow at each
    /// iteration, chosen so the large-batch fraction reaches ~1 after
    /// `nb_max_iterations` (solves alfa^(2*T) = 1 / large_batch_fraction_init).
    fn estimate_batch_growing_factor(&self, nb_max_iterations: usize, nbterms: usize) -> f64 {
        // m_zero is a fraction of nbterms; the resulting count must be >= 1 term.
        if self.m_zero * (nbterms as f64) < 1. {
            error!("m_zero fraction , should be greater than 1./ number of terms in sum");
            std::process::exit(1);
        }
        let log_alfa = (-self.large_batch_fraction_init.ln()) / (2. * nb_max_iterations as f64);
        let batch_growing_factor = log_alfa.exp();
        if batch_growing_factor <= 1. {
            // was println!; routed through the log facility for consistency with the rest of the file
            warn!("batch growing factor shoud be greater than 1. , possibly you can reduce number of iterations ");
        }
        debug!(
            " upper bound for batch_growing_factor {:2.4E}",
            batch_growing_factor
        );
        batch_growing_factor
    }

    /// Computes all batch-size quantities for iteration `j`:
    /// large batch grows like alfa_j^2, mini batch like alfa_j, the mini-batch
    /// count parameter like alfa_j^(3/2), and the step size decays like
    /// 1/sqrt(alfa_j), where alfa_j = batch_growing_factor^j.
    fn get_batch_size_at_jstep(
        &self,
        batch_growing_factor: f64,
        nbterms: usize,
        j: usize,
    ) -> BatchSizeInfo {
        let alfa_j = batch_growing_factor.powi(j as i32);
        // Cap the large batch at large_batch_max_fraction of the terms;
        // with few terms (<= 100) allow the whole sum.
        let max_large_batch_size = if nbterms > 100 {
            (nbterms as f64 * self.large_batch_max_fraction).ceil() as usize
        } else {
            nbterms
        };
        // A mini batch never exceeds 1% of the terms.
        let max_mini_batch_size = (nbterms as f64 / 100.).ceil() as usize;
        let large_batch_size =
            ((self.large_batch_fraction_init * (nbterms as f64) * alfa_j * alfa_j).ceil() as usize)
                .min(max_large_batch_size);
        let mini_batch_size =
            ((self.mini_batch_size_init as f64 * alfa_j).floor() as usize).min(max_mini_batch_size);
        let nb_mini_batch_parameter = self.m_zero * (nbterms as f64) * alfa_j.powf(1.5);
        let step_size = self.eta_zero / alfa_j.sqrt();
        BatchSizeInfo {
            _step: j,
            large_batch: large_batch_size,
            mini_batch: mini_batch_size,
            nb_mini_batch_parameter,
            step_size,
        }
    }

    /// Number of small mini-batch steps for this iteration:
    /// ceil(m_j / b_j), capped by the large batch size.
    fn get_nb_small_mini_batches(&self, batch_size_info: &BatchSizeInfo) -> usize {
        // nb_mini_batch_parameter is already f64; the original's `as f64` was redundant
        let m_j = batch_size_info.nb_mini_batch_parameter;
        let b_j = batch_size_info.mini_batch as f64;
        let n_j = ((m_j / b_j).ceil() as usize).min(batch_size_info.large_batch);
        debug!(
            " nb small mini batch {:?} m_j {:2.4E} b_j : {:2.4E} ",
            n_j, m_j, b_j
        );
        n_j
    }
}
fn sample_without_replacement_reservoir(
size_asked: usize,
in_terms: &[usize],
rng: &mut Xoshiro256PlusPlus,
) -> Vec<usize> {
let mut out_terms = Vec::<usize>::with_capacity(size_asked.min(in_terms.len()));
for i in 0..size_asked {
out_terms.push(in_terms[i]);
}
let mut xsi: f64;
xsi = rand_distr::Standard.sample(rng);
let mut w: f64 = (xsi.ln() / (size_asked as f64)).exp();
let mut s = size_asked - 1;
while s < in_terms.len() {
xsi = rand_distr::Standard.sample(rng);
s = s + (xsi.ln() / (1. - w).ln()).floor() as usize + 1;
if s < in_terms.len() {
xsi = rand_distr::Standard.sample(rng);
let idx = (size_asked as f64 * xsi).floor() as usize;
out_terms[idx] = in_terms[s];
xsi = rand_distr::Standard.sample(rng);
w = w * (xsi.ln() / (size_asked as f64)).exp();
}
}
out_terms
}
impl Default for StochasticControlledGradientDescent {
    /// Default configuration: eta_0 = 0.1, m_0 = 0.04, unit mini batch,
    /// large batch starting at 10% of the terms and capped at 10%,
    /// with the same fixed rng seed as `new`.
    fn default() -> Self {
        let rng = Xoshiro256PlusPlus::seed_from_u64(4664397);
        Self {
            rng,
            eta_zero: 0.1,
            m_zero: 0.04,
            mini_batch_size_init: 1,
            large_batch_fraction_init: 0.1,
            large_batch_max_fraction: 0.1,
        }
    }
}
impl<D: Dimension, F: SummationC1<D>> Minimizer<D, F, usize>
    for StochasticControlledGradientDescent
{
    type Solution = Solution<D>;

    /// Runs the SCSG-style minimization of `function` from `initial_position`.
    ///
    /// At every iteration a large batch gradient is estimated, then `n_j` small
    /// mini-batch steps are taken with the variance-reduced direction
    /// `grad(current) - grad(origin) + large_batch_gradient`.
    ///
    /// `max_iterations` is mandatory: the batch growth schedule is derived from
    /// the iteration budget. Passing `None` aborts with an explicit message
    /// (the original used a bare `unwrap`).
    fn minimize(
        &self,
        function: &F,
        initial_position: &Array<f64, D>,
        max_iterations: Option<usize>,
    ) -> Solution<D> {
        let cpu_start = ProcessTime::now();
        let sys_now = SystemTime::now();
        let mut position = initial_position.clone();
        let mut value = function.value(&position);
        let nb_max_iterations =
            max_iterations.expect("minimize : SCSG requires Some(max_iterations)");
        // Buffer for the current descent direction, reused across iterations.
        let mut direction: Array<f64, D> = position.clone();
        direction.fill(0.);
        if log_enabled!(Info) {
            info!("Starting with y = {:e} \n for x = {:?}", value, position);
        } else {
            info!("Starting with y = {:e}", value);
        }
        trace!("nb_max_iterations {:?}", nb_max_iterations);
        let mut iteration: usize = 0;
        // Clone the rng so minimize can keep &self (trait signature).
        let mut rng = self.rng.clone();
        let nb_terms = function.terms();
        let mut monitoring = IterationRes::<D>::new(nb_max_iterations, SolMode::Last);
        // nb_terms already holds function.terms(); the original called it twice.
        let batch_growing_factor =
            self.estimate_batch_growing_factor(nb_max_iterations, nb_terms);
        // Gradient buffers, allocated once and refilled each iteration.
        let mut large_batch_gradient: Array<f64, D> = position.clone();
        large_batch_gradient.fill(0.);
        let mut mini_batch_gradient_current: Array<f64, D> = position.clone();
        mini_batch_gradient_current.fill(0.);
        let mut mini_batch_gradient_origin: Array<f64, D> = position.clone();
        mini_batch_gradient_origin.fill(0.);
        // Index pool the batches are sampled from.
        let all_indexes: Vec<usize> = (0..nb_terms).collect();
        loop {
            let iter_params =
                self.get_batch_size_at_jstep(batch_growing_factor, nb_terms, iteration);
            let n_j = self.get_nb_small_mini_batches(&iter_params);
            let step_size = iter_params.step_size;
            // Anchor gradient over a large batch at the current position.
            let large_batch_indexes = sample_without_replacement_reservoir(
                iter_params.large_batch,
                &all_indexes,
                &mut rng,
            );
            trace!("\n iter {:?} got large batch size {:?}, nb mini batch {:?}, mini batch size {:?}, step {:2.4E}",
                iteration, large_batch_indexes.len(), n_j, iter_params.mini_batch, iter_params.step_size);
            function.mean_partial_gradient(
                &position,
                &large_batch_indexes,
                &mut large_batch_gradient,
            );
            // Origin for the control variate: position at the start of the inner loop.
            let position_before_mini_batch = position.clone();
            for _k in 0..n_j {
                let terms = sample_without_replacement_reservoir(
                    iter_params.mini_batch,
                    &all_indexes,
                    &mut rng,
                );
                function.mean_partial_gradient(&position, &terms, &mut mini_batch_gradient_current);
                function.mean_partial_gradient(
                    &position_before_mini_batch,
                    &terms,
                    &mut mini_batch_gradient_origin,
                );
                // Variance-reduced direction on the same mini batch at both positions.
                direction = &mini_batch_gradient_current - &mini_batch_gradient_origin
                    + &large_batch_gradient;
                position = position - step_size * &direction;
            }
            iteration += 1;
            value = function.value(&position);
            let gradnorm = norm_l2(&direction);
            monitoring.push(value, &position, gradnorm);
            if log_enabled!(Debug) {
                debug!(
                    "\n Iteration {:?} y = {:2.4E} , gradient norm : {:2.4}",
                    iteration, value, gradnorm
                );
            }
            if iteration >= nb_max_iterations {
                break;
            }
        }
        log::info!("\n StochasticControlledGradientDescent::minimize ; sys time(ms) {:?} cpu time(ms) {:?}",
            sys_now.elapsed().unwrap().as_millis(),
            cpu_start.elapsed().as_millis());
        info!(
            "Reached maximal number of iterations required {:?}, stopping optimization",
            nb_max_iterations
        );
        let rank = monitoring.check_monoticity();
        info!(" monotonous convergence from rank : {:?}", rank);
        Solution::new(position, value)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Sampling fewer items than the pool size must return exactly the
    /// requested number of items.
    #[test]
    fn test_reservoir_sampling() {
        let mut rng = Xoshiro256PlusPlus::seed_from_u64(4664397);
        let nb_asked = 100;
        // idiomatic collect() replaces the verbose Vec::from_iter turbofish
        let in_terms: Vec<usize> = (0..60000).collect();
        let selected_terms = sample_without_replacement_reservoir(nb_asked, &in_terms, &mut rng);
        assert_eq!(selected_terms.len(), nb_asked);
    }
}