1use crate::memory::*;
2
3pub enum AbortStrategy<T: Primitive> {
6 NoImprovement { threshold: T },
11 NoImprovementForXIterations { x: usize, threshold: T, abort_on_negative: bool },
19}
20impl<T: Primitive> AbortStrategy<T> {
21 pub(crate) fn create_logic(&self) -> Box<dyn AbortStrategyLogic<T>> {
22 match *self {
23 AbortStrategy::NoImprovementForXIterations {
24 x,
25 threshold,
26 abort_on_negative,
27 } => Box::new(NoImprovementForXIterationsLogic {
28 x,
29 threshold,
30 abort_on_negative,
31 prev_error: T::infinity(),
32 no_improvement_counter: 0,
33 }),
34 AbortStrategy::NoImprovement { threshold } => Box::new(NoImprovementLogic {
35 threshold,
36 prev_error: T::infinity(),
37 }),
38 }
39 }
40}
41
/// Stateful, per-run checker produced by `AbortStrategy::create_logic`.
pub(crate) trait AbortStrategyLogic<T: Primitive> {
    /// Feeds the error of the latest iteration into the checker.
    ///
    /// Returns `true` to keep iterating and `false` to abort; the implementations
    /// in this module also remember `error` for comparison on the next call.
    fn next(&mut self, error: T) -> bool;
}
51
52pub(crate) struct NoImprovementLogic<T: Primitive> {
53 threshold: T,
54 prev_error: T,
55}
56impl<T: Primitive> AbortStrategyLogic<T> for NoImprovementLogic<T> {
57 fn next(&mut self, error: T) -> bool {
58 let improvement = self.prev_error - error;
59 self.prev_error = error;
60 improvement > self.threshold
61 }
62}
63
64pub(crate) struct NoImprovementForXIterationsLogic<T: Primitive> {
65 x: usize,
66 threshold: T,
67 abort_on_negative: bool,
68 prev_error: T,
69 no_improvement_counter: usize,
70}
71impl<T: Primitive> AbortStrategyLogic<T> for NoImprovementForXIterationsLogic<T> {
72 fn next(&mut self, error: T) -> bool {
73 let improvement = self.prev_error - error;
74 self.prev_error = error;
75 if self.abort_on_negative && improvement < T::zero() {
76 return false;
78 }
79 if improvement > self.threshold {
80 self.no_improvement_counter = 0;
82 } else {
83 self.no_improvement_counter += 1;
85 }
86 self.no_improvement_counter < self.x
87 }
88}
89
#[cfg(test)]
mod tests {
    use super::*;

    /// Error sequences paired with the `next` result expected after each step.
    ///
    /// `NoImprovement` and `NoImprovementForXIterations { x: 1, .. }` must agree
    /// on every one of these, whatever the value of `abort_on_negative`.
    const SINGLE_STEP_CASES: &[&[(f64, bool)]] = &[
        // no change at all
        &[(3000.0, true), (3000.0, false)],
        // improves by less than the 0.0005 threshold
        &[(3000.0, true), (2999.99959, false)],
        // improves by more than the 0.0005 threshold
        &[(3000.0, true), (2999.99935, true)],
        // the final step makes the error worse
        &[
            (3000.0, true),
            (2000.0, true),
            (1999.99, true),
            (1999.99999999, false),
        ],
    ];

    /// Two stalls in a row with a real improvement in between.
    const PLATEAU: &[(f64, bool)] = &[
        (3000.0, true),
        (2000.0, true),
        (2000.0, true),
        (1999.0, true),
        (1999.0, true),
        (1999.0, false),
    ];

    /// Threshold used by every test strategy.
    fn default_threshold<T: Primitive>() -> T {
        T::from(0.0005).unwrap()
    }

    /// Convenience constructor for `NoImprovementForXIterations` test strategies.
    fn x_iterations<T: Primitive>(x: usize, abort_on_negative: bool) -> AbortStrategy<T> {
        AbortStrategy::NoImprovementForXIterations {
            x,
            threshold: default_threshold(),
            abort_on_negative,
        }
    }

    /// Builds fresh logic from `strategy` and checks each `(error, expected)` step in order.
    fn assert_steps<T: Primitive>(strategy: &AbortStrategy<T>, steps: &[(f64, bool)]) {
        let mut logic = strategy.create_logic();
        for (step, &(error, expected)) in steps.iter().enumerate() {
            let actual = logic.next(T::from(error).unwrap());
            assert_eq!(actual, expected, "step {} (error = {})", step, error);
        }
    }

    #[test]
    fn test_no_improvement_f32() {
        test_no_improvement::<f32>();
    }

    #[test]
    fn test_no_improvement_f64() {
        test_no_improvement::<f64>();
    }

    fn test_no_improvement<T: Primitive>() {
        let strategy = AbortStrategy::NoImprovement {
            threshold: default_threshold::<T>(),
        };
        for steps in SINGLE_STEP_CASES {
            assert_steps(&strategy, steps);
        }
    }

    #[test]
    fn test_no_improvement_for_x_iterations_f32() {
        test_no_improvement_for_x_iterations::<f32>();
    }

    #[test]
    fn test_no_improvement_for_x_iterations_f64() {
        test_no_improvement_for_x_iterations::<f64>();
    }

    fn test_no_improvement_for_x_iterations<T: Primitive>() {
        // x = 1 behaves like `NoImprovement`, with or without `abort_on_negative`.
        for &abort_on_negative in &[false, true] {
            let strategy = x_iterations::<T>(1, abort_on_negative);
            for steps in SINGLE_STEP_CASES {
                assert_steps(&strategy, steps);
            }
        }

        // With `abort_on_negative`, even a tiny increase aborts immediately,
        // although x = 2 would otherwise tolerate a single stalled step.
        let strict = x_iterations::<T>(2, true);
        assert_steps(&strict, &[(3000.0, true), (3001.0, false)]);
        assert_steps(&strict, &[(3000.0, true), (3000.0004, false)]);
        assert_steps(&strict, &[(3000.0, true), (3000.0007, false)]);

        // The counter resets on a real improvement and aborts on the second
        // consecutive stall, regardless of `abort_on_negative`.
        for &abort_on_negative in &[false, true] {
            assert_steps(&x_iterations::<T>(2, abort_on_negative), PLATEAU);
        }

        // A rebound in error aborts when `abort_on_negative` is set.
        assert_steps(
            &strict,
            &[
                (3000.0, true),
                (2000.0, true),
                (2000.0, true),
                (1999.0, true),
                (2999.0, false),
            ],
        );
    }
}