use log::{debug, info, log_enabled, Level::Trace};
use rand::{SeedableRng, random};
use rand::seq::SliceRandom;
use rand_pcg::Pcg64Mcg;

use types::{Minimizer, Solution, Summation1};


/// Provides _stochastic_ gradient descent optimization.
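///
/// # Examples
///
/// A minimal usage sketch; `my_summation` stands in for any value implementing
/// `Summation1` and is not part of this crate:
///
/// ```ignore
/// let mut sgd = StochasticGradientDescent::new();
/// sgd.seed(42).mini_batch(8).max_iterations(Some(500));
///
/// let solution = sgd.minimize(&my_summation, vec![1.0, -2.0]);
/// ```
///
/// The returned `Solution` bundles the final position and the corresponding
/// function value.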
pub struct StochasticGradientDescent {
    rng: Pcg64Mcg,
    max_iterations: Option<u64>,
    mini_batch: usize,
    step_width: f64
}

impl StochasticGradientDescent {
    /// Creates a new `StochasticGradientDescent` optimizer using the following defaults:
    ///
    /// - **`step_width`** = `0.01`
    /// - **`mini_batch`** = `1`
    /// - **`max_iterations`** = `Some(1000)`
    ///
    /// The underlying random number generator is randomly seeded.
    pub fn new() -> StochasticGradientDescent {
        StochasticGradientDescent {
            rng: Pcg64Mcg::new(random()),
            max_iterations: Some(1000),
            mini_batch: 1,
            step_width: 0.01
        }
    }

    /// Seeds the random number generator using the supplied `seed`.
    ///
    /// This is useful to create reproducible results.
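    ///
    /// For example, two identically seeded optimizers shuffle the terms in the
    /// same order and hence produce identical iterates:
    ///
    /// ```ignore
    /// let mut a = StochasticGradientDescent::new();
    /// let mut b = StochasticGradientDescent::new();
    /// a.seed(1234);
    /// b.seed(1234);
    /// ```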
    pub fn seed(&mut self, seed: u64) -> &mut Self {
        self.rng = Pcg64Mcg::seed_from_u64(seed);
        self
    }

    /// Adjusts the maximal number of iterations to run. A value of `None` removes the
    /// iteration limit; note that this is currently the only stopping criterion, so
    /// `minimize` will not terminate without one.
    pub fn max_iterations(&mut self, max_iterations: Option<u64>) -> &mut Self {
        assert!(max_iterations.map_or(true, |max_iterations| max_iterations > 0));

        self.max_iterations = max_iterations;
        self
    }

    /// Adjusts the mini batch size, i.e., the maximal number of terms considered in a
    /// single step.
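    ///
    /// For example, `10` terms and a mini batch size of `4` yield three batches of
    /// sizes `4`, `4` and `2` per iteration, mirroring how the optimizer chunks the
    /// shuffled term indices:
    ///
    /// ```
    /// let terms: Vec<usize> = (0..10).collect();
    /// let batches: Vec<_> = terms.chunks(4).collect();
    /// assert_eq!(batches.len(), 3);
    /// assert_eq!(batches[2], &[8, 9]);
    /// ```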
    pub fn mini_batch(&mut self, mini_batch: usize) -> &mut Self {
        assert!(mini_batch > 0);

        self.mini_batch = mini_batch;
        self
    }

    /// Adjusts the step width applied in each mini batch step.
    pub fn step_width(&mut self, step_width: f64) -> &mut Self {
        assert!(step_width > 0.0);

        self.step_width = step_width;
        self
    }
}

impl Default for StochasticGradientDescent {
    fn default() -> Self {
        Self::new()
    }
}

impl<F: Summation1> Minimizer<F> for StochasticGradientDescent {
    type Solution = Solution;

    fn minimize(&self, function: &F, initial_position: Vec<f64>) -> Solution {
        let mut position = initial_position;
        let mut value = function.value(&position);

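        // log the full position only when trace logging is enabled,
        // since the position vector may be large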
        if log_enabled!(Trace) {
            info!("Starting with y = {:?} for x = {:?}", value, position);
        } else {
            info!("Starting with y = {:?}", value);
        }

        let mut iteration = 0;
        let mut terms: Vec<_> = (0..function.terms()).collect();
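        // clone the RNG so `minimize` can take `&self` while repeated runs
        // with the same seed stay reproducible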
        let mut rng = self.rng.clone();

        loop {
            // ensure that we don't run into cycles
            terms.shuffle(&mut rng);

            for batch in terms.chunks(self.mini_batch) {
                let gradient = function.partial_gradient(&position, batch);

                // step into the direction of the negative gradient
                for (x, g) in position.iter_mut().zip(gradient) {
                    *x -= self.step_width * g;
                }
            }

            value = function.value(&position);

            iteration += 1;

            if log_enabled!(Trace) {
                debug!("Iteration {:6}: y = {:?}, x = {:?}", iteration, value, position);
            } else {
                debug!("Iteration {:6}: y = {:?}", iteration, value);
            }

            let reached_max_iterations = self.max_iterations.map_or(false,
                |max_iterations| iteration == max_iterations);

            if reached_max_iterations {
                info!("Reached maximal number of iterations, stopping optimization");
                return Solution::new(position, value);
            }
        }
    }
}