error_predictive_learning/lib.rs
1#![deny(missing_docs)]
2
3//! # Error Predictive Learning
4//! Black-box learning algorithm using error prediction levels
5//!
//! This is a very simple black-box learning algorithm which
//! uses higher-order error prediction to improve the
//! speed and accuracy of the search for local minima.
9//!
10//! See paper about [Error Predictive Learning](https://github.com/advancedresearch/path_semantics/blob/master/papers-wip/error-predictive-learning.pdf)
11//!
12//! ### Error prediction levels
13//!
14//! In error predictive learning, extra terms are added to the error
15//! function such that the search algorithm must learn to predict error,
16//! error in predicted error, and so on.
17//! This information is used in a non-linear way to adapt search behavior,
18//! which in turn affects error prediction etc.
19//!
//! This algorithm is useful for numerical function approximation
//! of a few variables, due to its high accuracy.
22//!
23//! ### Reset intervals
24//!
//! In black-box learning, there are no assumptions about the function.
//! This makes it hard to use domain-specific optimizations such as Newton's method.
//! The learning algorithm needs to build up momentum in other ways.
28//!
//! Counter-intuitively, forgetting the momentum from time to time
//! and rebuilding it might improve the search.
//! This is possible because re-learning momentum at a local point is relatively cheap.
//! The learning algorithm can take advantage of locally specific knowledge
//! to make up for the losses from forgetting the momentum.
34
35pub mod utils;
36
/// Stores training settings.
///
/// Passed by value to [`train`].
#[derive(Copy, Clone, Debug)]
pub struct TrainingSettings {
    /// Acceptable accuracy in error.
    ///
    /// Training succeeds as soon as the raw error returned by the
    /// black-box function is less than or equal to this value.
    pub accuracy_error: f64,
    /// The minimum step value.
    ///
    /// It is used as:
    /// - the value of the first error prediction after every reset,
    /// - the step of the last error prediction level,
    /// - the floor applied to error prediction weights whenever they are adjusted.
    ///
    /// When `error_predictions` is `0`, every weight moves by the fixed
    /// amount `elasticity * step`.
    pub step: f64,
    /// Maximum number of iterations.
    pub max_iterations: u64,
    /// The number of error prediction terms.
    ///
    /// More terms accelerate the search, but might lead to instability.
    pub error_predictions: usize,
    /// The interval, in iterations, at which error predictions are reset,
    /// in case they are far off or become unstable.
    ///
    /// Training also gives up when the error stays unchanged for twice this
    /// many iterations, so this should be greater than zero: with `0`,
    /// training returns `Err` on the first iteration that does not already
    /// meet `accuracy_error`.
    pub reset_interval: u64,
    /// A factor greater than zero to prevent under- or over-stepping.
    ///
    /// E.g. `0.95`
    ///
    /// This is used because predicted errors do not provide
    /// information about the gradient directly in the domain.
    /// Elasticity is used to estimate the gradient: every step
    /// is multiplied by this factor.
    pub elasticity: f64,
    /// Whether to print the current fit to stdout at each reset interval.
    pub debug: bool,
}
66
/// Outcome of a training run: the fit reached when training stopped.
#[derive(Debug, Clone)]
pub struct Fit {
    /// Raw error of `weights`, as returned by the black-box function.
    pub error: f64,
    /// Weights of the best fit found so far.
    pub weights: Vec<f64>,
    /// Error prediction weights at the moment this fit was recorded.
    pub error_predictions: Vec<f64>,
    /// Number of iterations performed to produce this result.
    pub iterations: u64,
}
79
80/// Trains to fit a vector of weights on a black-box function returning error.
81///
82/// Returns `Ok` if acceptable accuracy error was achieved.
83/// Returns `Err` if exceeding max iterations or score was unchanged
84/// for twice the reset interval.
85pub fn train<F: Fn(&[f64]) -> f64>(
86 settings: TrainingSettings,
87 weights: &[f64],
88 f: F
89) -> Result<Fit, Fit> {
90 let mut ws = vec![0.0; settings.error_predictions + weights.len()];
91 for i in 0..weights.len() {
92 ws[i + settings.error_predictions] = weights[i];
93 }
94 if settings.error_predictions > 0 {
95 ws[0] = settings.step;
96 }
97 let eval = |ws: &[f64]| {
98 let mut score = f(&ws[settings.error_predictions..]);
99 for i in 0..settings.error_predictions {
100 score += (score - ws[i]).abs();
101 }
102 score
103 };
104 let step = |ws: &[f64], i: usize| {
105 settings.elasticity * if i + 1 < settings.error_predictions {
106 // Use next error prediction level for change.
107 ws[i + 1]
108 } else if i + 1 == settings.error_predictions {
109 // The last error prediction level uses normal step.
110 settings.step
111 } else if settings.error_predictions > 0 {
112 // Adjust step to predicted error.
113 ws[0]
114 } else {
115 settings.step
116 }
117 };
118 let check = |w: &mut f64, i: usize| {
119 if i < settings.error_predictions {
120 if *w <= settings.step {*w = settings.step}
121 }
122 };
123 let mut iterations = 0;
124 // Keep track of last score to detect unchanged loop.
125 let mut last_score: Option<f64> = None;
126 let mut last_score_iterations = 0;
127 loop {
128 // Evaluate score without error predictions.
129 let score = f(&ws[settings.error_predictions..]);
130 if score <= settings.accuracy_error {
131 return Ok(Fit {
132 error: score,
133 weights: ws[settings.error_predictions..].into(),
134 error_predictions: ws[0..settings.error_predictions].into(),
135 iterations,
136 })
137 } else if iterations >= settings.max_iterations ||
138 last_score_iterations >= 2 * settings.reset_interval {
139 return Err(Fit {
140 error: score,
141 weights: ws[settings.error_predictions..].into(),
142 error_predictions: ws[0..settings.error_predictions].into(),
143 iterations,
144 })
145 }
146 if last_score == Some(score) {
147 last_score_iterations += 1;
148 } else {
149 last_score_iterations = 0;
150 }
151 last_score = Some(score);
152 // Reset error predictions.
153 if iterations % settings.reset_interval == 0 {
154 if settings.debug {
155 println!("{:?}", Fit {
156 error: score,
157 weights: ws[settings.error_predictions..].into(),
158 error_predictions: ws[0..settings.error_predictions].into(),
159 iterations,
160 });
161 }
162 for i in 0..settings.error_predictions {
163 ws[i] = 0.0;
164 }
165 ws[0] = settings.step;
166 }
167 // Change eight weight in either direction and pick the best.
168 // This also changes the error prediction weights.
169 for i in 0..ws.len() {
170 let score = eval(&ws);
171 let old = ws[i];
172 let step = step(&ws, i);
173 ws[i] += step;
174 check(&mut ws[i], i);
175 let score_up = eval(&ws);
176 ws[i] -= 2.0 * step;
177 check(&mut ws[i], i);
178 let score_down = eval(&ws);
179 if score <= score_up && score <= score_down {
180 ws[i] = old;
181 } else if score_up < score_down {
182 ws[i] += 2.0 * step;
183 }
184 }
185 iterations += 1;
186 }
187}