error_predictive_learning/
lib.rs

1#![deny(missing_docs)]
2
3//! # Error Predictive Learning
4//! Black-box learning algorithm using error prediction levels
5//!
6//! This is a very simple black-box learning algorithm which
7//! uses higher order error prediction to improve
8//! speed and accuracy of search to find local minima.
9//!
10//! See paper about [Error Predictive Learning](https://github.com/advancedresearch/path_semantics/blob/master/papers-wip/error-predictive-learning.pdf)
11//!
12//! ### Error prediction levels
13//!
14//! In error predictive learning, extra terms are added to the error
15//! function such that the search algorithm must learn to predict error,
16//! error in predicted error, and so on.
17//! This information is used in a non-linear way to adapt search behavior,
18//! which in turn affects error prediction etc.
19//!
20//! This algorithm is useful for numerical function approximation
21//! of few variables due to high accuracy.
22//!
23//! ### Reset intervals
24//!
25//! In black-box learning, there are no assumptions about the function.
26//! This makes it hard to use domain specific optimizations such as Newton's method.
27//! The learning algorithm needs to build up momentum in other ways.
28//!
29//! Counter-intuitively, forgetting the momentum from time to time
30//! and rebuilding it might improve the search.
31//! This is possible because re-learning momentum at a local point is relatively cheap.
32//! The learning algorithm can take advantage of local specific knowledge
33//! to regain the losses from forgetting the momentum.
34
35pub mod utils;
36
/// Configuration for a training run.
#[derive(Debug, Clone, Copy)]
pub struct TrainingSettings {
    /// Training succeeds once the error is at or below this value.
    pub accuracy_error: f64,
    /// The smallest step the search will take.
    ///
    /// Error prediction weights are never allowed to drop below this value.
    /// With `error_predictions` set to `0`, every weight is moved by this
    /// fixed amount (scaled by `elasticity`).
    pub step: f64,
    /// Upper bound on the number of iterations before giving up.
    pub max_iterations: u64,
    /// How many error prediction terms to learn alongside the weights.
    ///
    /// Additional terms speed up the search, but may make it unstable.
    pub error_predictions: usize,
    /// Number of iterations between resets of the error predictions,
    /// to recover when they are far off or have become unstable.
    ///
    /// Training also gives up when the error has stayed unchanged
    /// for twice this many iterations.
    pub reset_interval: u64,
    /// A factor greater than zero that scales every step,
    /// used to prevent under- or over-stepping.
    ///
    /// E.g. `0.95`
    ///
    /// Predicted errors do not provide information about the gradient
    /// directly in the domain, so elasticity is used to estimate it.
    pub elasticity: f64,
    /// When `true`, the current fit is printed at every reset interval.
    pub debug: bool,
}
66
/// The outcome of a training run.
#[derive(Debug, Clone)]
pub struct Fit {
    /// Error of the reported weights.
    pub error: f64,
    /// The best weights found so far.
    pub weights: Vec<f64>,
    /// The learned error prediction weights.
    pub error_predictions: Vec<f64>,
    /// How many iterations were run to produce this result.
    pub iterations: u64,
}
79
80/// Trains to fit a vector of weights on a black-box function returning error.
81///
82/// Returns `Ok` if acceptable accuracy error was achieved.
83/// Returns `Err` if exceeding max iterations or score was unchanged
84/// for twice the reset interval.
85pub fn train<F: Fn(&[f64]) -> f64>(
86    settings: TrainingSettings,
87    weights: &[f64],
88    f: F
89) -> Result<Fit, Fit> {
90    let mut ws = vec![0.0; settings.error_predictions + weights.len()];
91    for i in 0..weights.len() {
92        ws[i + settings.error_predictions] = weights[i];
93    }
94    if settings.error_predictions > 0 {
95        ws[0] = settings.step;
96    }
97    let eval = |ws: &[f64]| {
98        let mut score = f(&ws[settings.error_predictions..]);
99        for i in 0..settings.error_predictions {
100            score += (score - ws[i]).abs();
101        }
102        score
103    };
104    let step = |ws: &[f64], i: usize| {
105        settings.elasticity * if i + 1 < settings.error_predictions {
106            // Use next error prediction level for change.
107            ws[i + 1]
108        } else if i + 1 == settings.error_predictions {
109            // The last error prediction level uses normal step.
110            settings.step
111        } else if settings.error_predictions > 0 {
112            // Adjust step to predicted error.
113            ws[0]
114        } else {
115            settings.step
116        }
117    };
118    let check = |w: &mut f64, i: usize| {
119        if i < settings.error_predictions {
120            if *w <= settings.step {*w = settings.step}
121        }
122    };
123    let mut iterations = 0;
124    // Keep track of last score to detect unchanged loop.
125    let mut last_score: Option<f64> = None;
126    let mut last_score_iterations = 0;
127    loop {
128        // Evaluate score without error predictions.
129        let score = f(&ws[settings.error_predictions..]);
130        if score <= settings.accuracy_error {
131            return Ok(Fit {
132                error: score,
133                weights: ws[settings.error_predictions..].into(),
134                error_predictions: ws[0..settings.error_predictions].into(),
135                iterations,
136            })
137        } else if iterations >= settings.max_iterations ||
138            last_score_iterations >= 2 * settings.reset_interval {
139            return Err(Fit {
140                error: score,
141                weights: ws[settings.error_predictions..].into(),
142                error_predictions: ws[0..settings.error_predictions].into(),
143                iterations,
144            })
145        }
146        if last_score == Some(score) {
147            last_score_iterations += 1;
148        } else {
149            last_score_iterations = 0;
150        }
151        last_score = Some(score);
152        // Reset error predictions.
153        if iterations % settings.reset_interval == 0 {
154            if settings.debug {
155                println!("{:?}", Fit {
156                    error: score,
157                    weights: ws[settings.error_predictions..].into(),
158                    error_predictions: ws[0..settings.error_predictions].into(),
159                    iterations,
160                });
161            }
162            for i in 0..settings.error_predictions {
163                ws[i] = 0.0;
164            }
165            ws[0] = settings.step;
166        }
167        // Change eight weight in either direction and pick the best.
168        // This also changes the error prediction weights.
169        for i in 0..ws.len() {
170            let score = eval(&ws);
171            let old = ws[i];
172            let step = step(&ws, i);
173            ws[i] += step;
174            check(&mut ws[i], i);
175            let score_up = eval(&ws);
176            ws[i] -= 2.0 * step;
177            check(&mut ws[i], i);
178            let score_down = eval(&ws);
179            if score <= score_up && score <= score_down {
180                ws[i] = old;
181            } else if score_up < score_down {
182                ws[i] += 2.0 * step;
183            }
184        }
185        iterations += 1;
186    }
187}