1use crate::metric::{
2 Metric, MetricName,
3 store::{Aggregate, Direction, EventStoreClient, Split},
4};
5
/// The condition an early stopping strategy evaluates to decide when to stop training.
#[derive(Clone)]
pub enum StoppingCondition {
    /// Stop once the number of epochs elapsed since the best recorded epoch reaches
    /// `n_epochs`.
    NoImprovementSince {
        /// Number of epochs without a new best value at which stopping is triggered.
        n_epochs: usize,
    },
}
15
/// A strategy that decides whether training should be stopped early.
pub trait EarlyStoppingStrategy: Send {
    /// Returns `true` if training should stop, given the current `epoch` and the event
    /// `store` the strategy may query.
    ///
    /// Takes `&mut self`, so implementations may update internal state on every call.
    fn should_stop(&mut self, epoch: usize, store: &EventStoreClient) -> bool;
}
21
/// An [`EarlyStoppingStrategy`] that can be duplicated behind a trait object.
///
/// `Clone` itself cannot be used through `dyn`, so cloning is exposed through
/// [`clone_box`](CloneEarlyStoppingStrategy::clone_box) instead.
pub trait CloneEarlyStoppingStrategy: EarlyStoppingStrategy + Send {
    /// Returns a boxed copy of this strategy.
    fn clone_box(&self) -> Box<dyn CloneEarlyStoppingStrategy>;
}
27
28impl<T> CloneEarlyStoppingStrategy for T
31where
32 T: EarlyStoppingStrategy + Clone + Send + 'static,
33{
34 fn clone_box(&self) -> Box<dyn CloneEarlyStoppingStrategy> {
35 Box::new(self.clone())
36 }
37}
38
39impl Clone for Box<dyn CloneEarlyStoppingStrategy> {
41 fn clone(&self) -> Box<dyn CloneEarlyStoppingStrategy> {
42 self.clone_box()
43 }
44}
45
/// An [`EarlyStoppingStrategy`] driven by the value of a single metric.
///
/// The strategy remembers the best value seen so far and the epoch at which it occurred,
/// and evaluates its [`StoppingCondition`] against them.
#[derive(Clone)]
pub struct MetricEarlyStoppingStrategy {
    /// Condition evaluated when the current epoch did not produce a new best value.
    condition: StoppingCondition,
    /// Name of the metric looked up in the event store.
    metric_name: MetricName,
    /// Aggregate passed to the event store when looking up the metric.
    aggregate: Aggregate,
    /// Whether a lower or a higher metric value counts as an improvement.
    direction: Direction,
    /// Split passed to the event store when looking up the metric.
    split: Split,
    /// Epoch at which `best_value` was recorded (initialized to `1`).
    best_epoch: usize,
    /// Best metric value observed so far.
    best_value: f64,
    /// If set, the strategy never stops while `epoch <= warmup_epochs`.
    warmup_epochs: Option<usize>,
}
59
60impl EarlyStoppingStrategy for MetricEarlyStoppingStrategy {
61 fn should_stop(&mut self, epoch: usize, store: &EventStoreClient) -> bool {
62 let current_value =
63 match store.find_metric(&self.metric_name, epoch, self.aggregate, self.split) {
64 Some(value) => value,
65 None => {
66 log::warn!("Can't find metric for early stopping.");
67 return false;
68 }
69 };
70
71 let is_best = match self.direction {
72 Direction::Lowest => current_value < self.best_value,
73 Direction::Highest => current_value > self.best_value,
74 };
75
76 if is_best {
77 log::info!(
78 "New best epoch found {} {}: {}",
79 epoch,
80 self.metric_name,
81 current_value
82 );
83 self.best_value = current_value;
84 self.best_epoch = epoch;
85 return false;
86 }
87
88 if let Some(warmup_epochs) = self.warmup_epochs
89 && epoch <= warmup_epochs
90 {
91 return false;
92 }
93
94 match self.condition {
95 StoppingCondition::NoImprovementSince { n_epochs } => {
96 let should_stop = epoch - self.best_epoch >= n_epochs;
97
98 if should_stop {
99 log::info!(
100 "Stopping training loop, no improvement since epoch {}, {}: {}, current \
101 epoch {}, {}: {}",
102 self.best_epoch,
103 self.metric_name,
104 self.best_value,
105 epoch,
106 self.metric_name,
107 current_value
108 );
109 }
110
111 should_stop
112 }
113 }
114 }
115}
116
117impl MetricEarlyStoppingStrategy {
118 pub fn new<Me: Metric>(
125 metric: &Me,
126 aggregate: Aggregate,
127 direction: Direction,
128 split: Split,
129 condition: StoppingCondition,
130 ) -> Self {
131 let init_value = match direction {
132 Direction::Lowest => f64::MAX,
133 Direction::Highest => f64::MIN,
134 };
135
136 Self {
137 metric_name: metric.name(),
138 condition,
139 aggregate,
140 direction,
141 split,
142 best_epoch: 1,
143 best_value: init_value,
144 warmup_epochs: None,
145 }
146 }
147
148 pub fn warmup_epochs(&self) -> Option<usize> {
152 self.warmup_epochs
153 }
154
155 pub fn with_warmup_epochs(self, warmup: Option<usize>) -> Self {
162 Self {
163 warmup_epochs: warmup,
164 ..self
165 }
166 }
167}
168
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use crate::{
        EventProcessorTraining, TestBackend,
        logger::InMemoryMetricLogger,
        metric::{
            LossMetric,
            processor::{
                MetricsTraining, MinimalEventProcessor,
                test_utils::{end_epoch, process_train},
            },
            store::LogEventStore,
        },
    };

    use super::*;

    #[test]
    fn never_early_stop_while_it_is_improving() {
        test_early_stopping(
            None,
            1,
            &[
                (&[0.5, 0.3], false, "Should not stop first epoch"),
                (&[0.4, 0.3], false, "Should not stop when improving"),
                (&[0.3, 0.3], false, "Should not stop when improving"),
                (&[0.2, 0.3], false, "Should not stop when improving"),
            ],
        );
    }

    #[test]
    fn early_stop_when_no_improvement_since_two_epochs() {
        test_early_stopping(
            None,
            2,
            &[
                (&[1.0, 0.5], false, "Should not stop first epoch"),
                (&[0.5, 0.3], false, "Should not stop when improving"),
                (
                    &[1.0, 3.0],
                    false,
                    "Should not stop first time it gets worse",
                ),
                (
                    &[1.0, 2.0],
                    true,
                    "Should stop since two following epochs didn't improve",
                ),
            ],
        );
    }

    #[test]
    fn early_stopping_with_warmup() {
        test_early_stopping(
            Some(3),
            2,
            &[
                (&[1.0, 0.5], false, "Should not stop during warmup"),
                (&[1.0, 0.5], false, "Should not stop during warmup"),
                (&[1.0, 0.5], false, "Should not stop during warmup"),
                (
                    &[1.0, 0.5],
                    true,
                    "Should stop when not improving after warmup",
                ),
            ],
        );
    }

    #[test]
    fn early_stop_when_stays_equal() {
        test_early_stopping(
            None,
            2,
            &[
                (&[0.5, 0.3], false, "Should not stop first epoch"),
                (
                    &[0.5, 0.3],
                    false,
                    "Should not stop first time it stays the same",
                ),
                (
                    &[0.5, 0.3],
                    true,
                    "Should stop since two following epochs didn't improve",
                ),
            ],
        );
    }

    /// Drives a loss-based [`MetricEarlyStoppingStrategy`] through a sequence of epochs.
    ///
    /// Each entry of `data` describes one epoch (starting at epoch 1): the loss points
    /// processed during the epoch, the expected result of `should_stop` once the epoch
    /// ends, and the message shown if the expectation does not hold.
    fn test_early_stopping(warmup: Option<usize>, n_epochs: usize, data: &[(&[f64], bool, &str)]) {
        let loss = LossMetric::<TestBackend>::new();
        let mut early_stopping = MetricEarlyStoppingStrategy::new(
            &loss,
            Aggregate::Mean,
            Direction::Lowest,
            Split::Train,
            StoppingCondition::NoImprovementSince { n_epochs },
        )
        .with_warmup_epochs(warmup);
        let mut store = LogEventStore::default();
        let mut metrics = MetricsTraining::<f64, f64>::default();

        store.register_logger(InMemoryMetricLogger::default());
        metrics.register_train_metric_numeric(loss);

        let store = Arc::new(EventStoreClient::new(store));
        let mut processor = MinimalEventProcessor::new(metrics, store.clone());

        let mut epoch = 1;
        processor.process_train(crate::LearnerEvent::Start);
        for (points, expected_stop, comment) in data {
            for point in points.iter() {
                process_train(&mut processor, *point, epoch);
            }
            end_epoch(&mut processor, epoch);

            assert_eq!(
                *expected_stop,
                early_stopping.should_stop(epoch, &store),
                "{comment}"
            );
            epoch += 1;
        }
    }
}