// File: kmeans/abort_strategy.rs

use crate::memory::*;

3/// Enum with possible abort strategies.
4/// These strategies specify when a running iteration (with the k-means calculation) is aborted.
5pub enum AbortStrategy<T: Primitive> {
6    /// This strategy aborts the calculation directly after an iteration produced no improvement where `improvement > threshold`
7    /// for the first time.
8    /// ## Fields:
9    /// - **threshold**: Threshold, used to detect an improvement (`improvement > threshold`)
10    NoImprovement { threshold: T },
11    /// This strategy aborts the calculation, when there have not been any improvements after **x** iterations,
12    /// where `improvement > threshold`.
13    /// ## Fields:
14    /// - **x**: The amount of consecutive without improvement, after which the calculation is aborted
15    /// - **threshold**: Threshold, used to detect an improvement (`improvement > threshold`)
16    /// - **abort_on_negative**: Specifies whether the strategy instantly aborts when a negative improvement occured (**true**), or if
17    ///   negative improvements are handled as "no improvements" (**false**).
18    NoImprovementForXIterations { x: usize, threshold: T, abort_on_negative: bool },
19}
20impl<T: Primitive> AbortStrategy<T> {
21    pub(crate) fn create_logic(&self) -> Box<dyn AbortStrategyLogic<T>> {
22        match *self {
23            AbortStrategy::NoImprovementForXIterations {
24                x,
25                threshold,
26                abort_on_negative,
27            } => Box::new(NoImprovementForXIterationsLogic {
28                x,
29                threshold,
30                abort_on_negative,
31                prev_error: T::infinity(),
32                no_improvement_counter: 0,
33            }),
34            AbortStrategy::NoImprovement { threshold } => Box::new(NoImprovementLogic {
35                threshold,
36                prev_error: T::infinity(),
37            }),
38        }
39    }
40}
41
42pub(crate) trait AbortStrategyLogic<T: Primitive> {
43    /// Function that has to be called once an iteration of the calculation ended, a new error was calculated.
44    /// ## Arguments
45    /// - **error**: The new **error (distsum), after an iteration
46    /// ## Returns
47    /// - **true** if the calculation should continue
48    /// - **false** if the calculation should abort
49    fn next(&mut self, error: T) -> bool;
50}
51
52pub(crate) struct NoImprovementLogic<T: Primitive> {
53    threshold: T,
54    prev_error: T,
55}
56impl<T: Primitive> AbortStrategyLogic<T> for NoImprovementLogic<T> {
57    fn next(&mut self, error: T) -> bool {
58        let improvement = self.prev_error - error;
59        self.prev_error = error;
60        improvement > self.threshold
61    }
62}
63
64pub(crate) struct NoImprovementForXIterationsLogic<T: Primitive> {
65    x: usize,
66    threshold: T,
67    abort_on_negative: bool,
68    prev_error: T,
69    no_improvement_counter: usize,
70}
71impl<T: Primitive> AbortStrategyLogic<T> for NoImprovementForXIterationsLogic<T> {
72    fn next(&mut self, error: T) -> bool {
73        let improvement = self.prev_error - error;
74        self.prev_error = error;
75        if self.abort_on_negative && improvement < T::zero() {
76            // Negative improvement, and instant abort is requested
77            return false;
78        }
79        if improvement > self.threshold {
80            // positive improvement: reset no-improv-counter
81            self.no_improvement_counter = 0;
82        } else {
83            // Still no improvement, count 1 up
84            self.no_improvement_counter += 1;
85        }
86        self.no_improvement_counter < self.x
87    }
88}
89
#[cfg(test)]
mod tests {
    use super::*;

    /// Improvement threshold used by every scenario in this module.
    const THRESHOLD: f64 = 0.0005;

    /// Error sequences paired with the expected continue/abort decision of each step. They give
    /// the same outcome for `NoImprovement` and for `NoImprovementForXIterations { x: 1, .. }`
    /// with either `abort_on_negative` setting.
    const SINGLE_STEP_SCENARIOS: &[&[(f64, bool)]] = &[
        &[(3000.0, true), (3000.0, false)],
        &[(3000.0, true), (2999.99959, false)],
        &[(3000.0, true), (2999.99935, true)],
        &[(3000.0, true), (2000.0, true), (1999.99, true), (1999.99999999, false)],
    ];

    /// With `x == 2` the run aborts on the second consecutive non-improving step; the
    /// improvement to 1999.0 in between resets the streak.
    const PLATEAU_SCENARIO: &[(f64, bool)] = &[
        (3000.0, true),
        (2000.0, true),
        (2000.0, true),
        (1999.0, true),
        (1999.0, true),
        (1999.0, false),
    ];

    /// Creates fresh logic for `strategy`, feeds it every error of `steps` in order and asserts
    /// that each call returns the paired decision (`true` = continue, `false` = abort).
    fn check<T: Primitive>(strategy: AbortStrategy<T>, steps: &[(f64, bool)]) {
        let mut logic = strategy.create_logic();
        for (index, &(error, expected)) in steps.iter().enumerate() {
            let decision = logic.next(T::from(error).unwrap());
            assert_eq!(decision, expected, "step {} (error = {})", index, error);
        }
    }

    #[test]
    fn test_no_improvement_f32() {
        test_no_improvement::<f32>();
    }

    #[test]
    fn test_no_improvement_f64() {
        test_no_improvement::<f64>();
    }

    fn test_no_improvement<T: Primitive>() {
        let threshold = T::from(THRESHOLD).unwrap();
        for &steps in SINGLE_STEP_SCENARIOS {
            check(AbortStrategy::NoImprovement { threshold }, steps);
        }
    }

    #[test]
    fn test_no_improvement_for_x_iterations_f32() {
        test_no_improvement_for_x_iterations::<f32>();
    }

    #[test]
    fn test_no_improvement_for_x_iterations_f64() {
        test_no_improvement_for_x_iterations::<f64>();
    }

    fn test_no_improvement_for_x_iterations<T: Primitive>() {
        let threshold = T::from(THRESHOLD).unwrap();
        let strategy = |x: usize, abort_on_negative: bool| {
            AbortStrategy::NoImprovementForXIterations {
                x,
                threshold,
                abort_on_negative,
            }
        };

        // x = 1: same outcomes as `NoImprovement`, whether or not negative improvements abort.
        for &abort_on_negative in &[false, true] {
            for &steps in SINGLE_STEP_SCENARIOS {
                check(strategy(1, abort_on_negative), steps);
            }
        }

        // abort_on_negative: any rise of the error aborts immediately, even one whose magnitude
        // is smaller than the threshold (e.g. 3000.0 -> 3000.0004).
        for &rising_error in &[3001.0, 3000.0004, 3000.0007] {
            check(strategy(2, true), &[(3000.0, true), (rising_error, false)]);
        }

        // x = 2: two consecutive non-improving steps are needed before aborting.
        for &abort_on_negative in &[false, true] {
            check(strategy(2, abort_on_negative), PLATEAU_SCENARIO);
        }

        // x = 2 with abort_on_negative: a rising error aborts before the streak reaches 2.
        check(
            strategy(2, true),
            &[(3000.0, true), (2000.0, true), (2000.0, true), (1999.0, true), (2999.0, false)],
        );
    }
}