#![deny(missing_docs)]

//! # Error Predictive Learning
//! Black-box learning algorithm using error prediction levels
//!
//! This is a simple black-box learning algorithm which uses
//! higher-order error prediction to improve the speed and accuracy
//! of the search for local minima.
//!
//! See paper about [Error Predictive Learning](https://github.com/advancedresearch/path_semantics/blob/master/papers-wip/error-predictive-learning.pdf)
//!
//! ### Error prediction levels
//!
//! In error predictive learning, extra terms are added to the error
//! function such that the search algorithm must learn to predict error,
//! error in predicted error, and so on.
//! This information is used in a non-linear way to adapt search behavior,
//! which in turn affects error prediction etc.
//!
//! Due to its high accuracy, this algorithm is useful for
//! numerical approximation of functions of few variables.
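//!
//! As a minimal sketch, the augmented error adds one penalty term per
//! prediction level; the values below are made up for illustration and
//! mirror the internal evaluation in `train`:
//!
//! ```
//! let error_predictions = [0.5, 0.1]; // hypothetical predicted-error weights
//! let mut score = 2.0; // raw error from the black-box function
//! for &p in &error_predictions {
//!     // Penalize the gap between the running error and its prediction.
//!     score += (score - p).abs();
//! }
//! assert!(score >= 2.0);
//! ```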
//!
//! ### Reset intervals
//!
//! In black-box learning, there are no assumptions about the function.
//! This makes it hard to use domain specific optimizations such as Newton's method.
//! The learning algorithm needs to build up momentum in other ways.
//!
//! Counter-intuitively, forgetting the momentum from time to time
//! and rebuilding it might improve the search.
//! This is possible because re-learning momentum at a local point is relatively cheap.
//! The learning algorithm can take advantage of local knowledge
//! to recover the losses from forgetting the momentum.
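//!
//! ### Example
//!
//! A minimal usage sketch, fitting a single weight to minimize the
//! distance to a target value (the settings are illustrative, not tuned;
//! the crate name is assumed to be `error_predictive_learning`):
//!
//! ```
//! use error_predictive_learning::{train, TrainingSettings};
//!
//! let settings = TrainingSettings {
//!     accuracy_error: 1e-10,
//!     step: 1e-4,
//!     max_iterations: 10_000,
//!     error_predictions: 2,
//!     reset_interval: 100,
//!     elasticity: 0.95,
//!     debug: false,
//! };
//! // Black-box error: absolute distance from the target value `3.0`.
//! match train(settings, &[0.0], |ws| (ws[0] - 3.0).abs()) {
//!     Ok(fit) => println!("Converged: {:?}", fit),
//!     Err(fit) => println!("Best so far: {:?}", fit),
//! }
//! ```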

pub mod utils;

/// Stores training settings.
#[derive(Copy, Clone, Debug)]
pub struct TrainingSettings {
    /// Acceptable accuracy in error.
    pub accuracy_error: f64,
    /// The minimum step value.
    ///
    /// When `error_predictions` is set to `0`, this is used as a fixed step.
    pub step: f64,
    /// Maximum number of iterations.
    pub max_iterations: u64,
    /// The number of error prediction terms.
    ///
    /// More terms accelerate the search, but might lead to instability.
    pub error_predictions: usize,
    /// The interval to reset error predictions,
    /// in case they are far off or become unstable.
    pub reset_interval: u64,
    /// A factor greater than zero to prevent under- or over-stepping.
    ///
    /// E.g. `0.95`
    ///
    /// This is used because predicted errors do not provide
    /// information about the gradient directly in the domain.
    /// Elasticity is used to estimate the gradient.
    pub elasticity: f64,
    /// Whether to print out the result at each reset interval.
    pub debug: bool,
}

/// Stores fit data.
#[derive(Clone, Debug)]
pub struct Fit {
    /// The error of the fit.
    pub error: f64,
    /// Weights of best fit (so far).
    pub weights: Vec<f64>,
    /// Error prediction weights.
    pub error_predictions: Vec<f64>,
    /// The number of iterations to produce the result.
    pub iterations: u64,
}

/// Trains a vector of weights to fit a black-box function that returns an error value.
///
/// Returns `Ok` if the acceptable accuracy error was achieved.
/// Returns `Err` if the maximum number of iterations was exceeded
/// or the score was unchanged for twice the reset interval.
pub fn train<F: Fn(&[f64]) -> f64>(
    settings: TrainingSettings,
    weights: &[f64],
    f: F
) -> Result<Fit, Fit> {
    // Layout: error prediction weights first, followed by the domain weights.
    let mut ws = vec![0.0; settings.error_predictions + weights.len()];
    for i in 0..weights.len() {
        ws[i + settings.error_predictions] = weights[i];
    }
    // Seed the first error prediction with the minimum step.
    if settings.error_predictions > 0 {
        ws[0] = settings.step;
    }
    // Evaluate the augmented error: the raw error plus one penalty
    // term per error prediction level.
    let eval = |ws: &[f64]| {
        let mut score = f(&ws[settings.error_predictions..]);
        for i in 0..settings.error_predictions {
            score += (score - ws[i]).abs();
        }
        score
    };
    // Compute the step size for weight `i`.
    let step = |ws: &[f64], i: usize| {
        settings.elasticity * if i + 1 < settings.error_predictions {
            // Use next error prediction level for change.
            ws[i + 1]
        } else if i + 1 == settings.error_predictions {
            // The last error prediction level uses normal step.
            settings.step
        } else if settings.error_predictions > 0 {
            // Adjust step to predicted error.
            ws[0]
        } else {
            settings.step
        }
    };
    // Clamp error prediction weights to at least the minimum step.
    let check = |w: &mut f64, i: usize| {
        if i < settings.error_predictions && *w <= settings.step {
            *w = settings.step;
        }
    };
    let mut iterations = 0;
    // Keep track of last score to detect unchanged loop.
    let mut last_score: Option<f64> = None;
    let mut last_score_iterations = 0;
    loop {
        // Evaluate score without error predictions.
        let score = f(&ws[settings.error_predictions..]);
        if score <= settings.accuracy_error {
            return Ok(Fit {
                error: score,
                weights: ws[settings.error_predictions..].into(),
                error_predictions: ws[0..settings.error_predictions].into(),
                iterations,
            })
        } else if iterations >= settings.max_iterations ||
            last_score_iterations >= 2 * settings.reset_interval {
            return Err(Fit {
                error: score,
                weights: ws[settings.error_predictions..].into(),
                error_predictions: ws[0..settings.error_predictions].into(),
                iterations,
            })
        }
        if last_score == Some(score) {
            last_score_iterations += 1;
        } else {
            last_score_iterations = 0;
        }
        last_score = Some(score);
        // Reset error predictions.
        if iterations % settings.reset_interval == 0 {
            if settings.debug {
                println!("{:?}", Fit {
                    error: score,
                    weights: ws[settings.error_predictions..].into(),
                    error_predictions: ws[0..settings.error_predictions].into(),
                    iterations,
                });
            }
            for i in 0..settings.error_predictions {
                ws[i] = 0.0;
            }
            ws[0] = settings.step;
        }
        // Change each weight in either direction and pick the best.
        // This also changes the error prediction weights.
        for i in 0..ws.len() {
            let score = eval(&ws);
            let old = ws[i];
            let step = step(&ws, i);
            ws[i] += step;
            check(&mut ws[i], i);
            let score_up = eval(&ws);
            ws[i] -= 2.0 * step;
            check(&mut ws[i], i);
            let score_down = eval(&ws);
            if score <= score_up && score <= score_down {
                ws[i] = old;
            } else if score_up < score_down {
                ws[i] += 2.0 * step;
            }
        }
        iterations += 1;
    }
}