use log::Level::Trace;
use rand::{SeedableRng, random};
use rand::seq::SliceRandom;
use rand_pcg::Pcg64Mcg;
use types::{Minimizer, Solution, Summation1};
/// A minimizer using stochastic gradient descent: each iteration shuffles the
/// summation's term indices and takes a fixed-width step against the partial
/// gradient of every mini-batch.
pub struct StochasticGradientDescent {
// Random number generator used to shuffle the term indices each iteration.
rng: Pcg64Mcg,
// Maximal number of iterations; `None` means no limit, in which case the
// optimization never terminates on its own.
max_iterations: Option<u64>,
// Number of terms per partial-gradient step.
mini_batch: usize,
// Fixed learning rate applied to every gradient step.
step_width: f64
}
impl StochasticGradientDescent {
    /// Creates an optimizer with a randomly seeded generator, no iteration
    /// limit, a mini-batch size of 1, and a step width of 0.01.
    pub fn new() -> StochasticGradientDescent {
        StochasticGradientDescent {
            rng: Pcg64Mcg::new(random()),
            max_iterations: None,
            mini_batch: 1,
            step_width: 0.01,
        }
    }

    /// Reseeds the random number generator so runs become reproducible.
    pub fn seed(&mut self, seed: u64) -> &mut Self {
        self.rng = Pcg64Mcg::seed_from_u64(seed);
        self
    }

    /// Sets the maximal number of iterations; `None` removes the limit.
    ///
    /// # Panics
    ///
    /// Panics if `max_iterations` is `Some(0)`.
    pub fn max_iterations(&mut self, max_iterations: Option<u64>) -> &mut Self {
        // A limit of zero iterations would return without ever stepping.
        assert!(
            max_iterations.map_or(true, |max_iterations| max_iterations > 0),
            "max_iterations must be positive when given",
        );
        self.max_iterations = max_iterations;
        self
    }

    /// Sets the number of terms per mini-batch.
    ///
    /// # Panics
    ///
    /// Panics if `mini_batch` is zero.
    pub fn mini_batch(&mut self, mini_batch: usize) -> &mut Self {
        assert!(mini_batch > 0, "mini_batch must be positive");
        self.mini_batch = mini_batch;
        self
    }

    /// Sets the step width (learning rate) of each gradient step.
    ///
    /// # Panics
    ///
    /// Panics if `step_width` is not strictly positive.
    pub fn step_width(&mut self, step_width: f64) -> &mut Self {
        assert!(step_width > 0.0, "step_width must be positive");
        self.step_width = step_width;
        self
    }
}
impl Default for StochasticGradientDescent {
fn default() -> Self {
Self::new()
}
}
impl<F: Summation1> Minimizer<F> for StochasticGradientDescent {
    type Solution = Solution;

    /// Runs mini-batch stochastic gradient descent from `initial_position`.
    ///
    /// Every iteration shuffles the term indices, walks them in chunks of
    /// `mini_batch`, and moves the position against the partial gradient of
    /// each chunk scaled by `step_width`. The loop stops only once the
    /// configured maximal number of iterations is reached; with no limit
    /// set it runs forever.
    fn minimize(&self, function: &F, initial_position: Vec<f64>) -> Solution {
        let mut position = initial_position;
        let mut value = function.value(&position);

        if log_enabled!(Trace) {
            info!("Starting with y = {:?} for x = {:?}", value, position);
        } else {
            info!("Starting with y = {:?}", value);
        }

        // Clone the generator so `minimize` can take `&self` yet still
        // draw a reproducible shuffle sequence per call.
        let mut rng = self.rng.clone();
        let mut indices: Vec<_> = (0..function.terms()).collect();
        let mut iteration = 0;

        loop {
            indices.shuffle(&mut rng);

            for batch in indices.chunks(self.mini_batch) {
                let gradient = function.partial_gradient(&position, batch);
                for (coordinate, derivative) in position.iter_mut().zip(gradient) {
                    *coordinate -= self.step_width * derivative;
                }
            }

            value = function.value(&position);
            iteration += 1;

            if log_enabled!(Trace) {
                debug!("Iteration {:6}: y = {:?}, x = {:?}", iteration, value, position);
            } else {
                debug!("Iteration {:6}: y = {:?}", iteration, value);
            }

            if self.max_iterations == Some(iteration) {
                info!("Reached maximal number of iterations, stopping optimization");
                return Solution::new(position, value);
            }
        }
    }
}