1use crate::metric::{
2 Metric,
3 store::{Aggregate, Direction, EventStoreClient, Split},
4};
5
/// The condition an [`EarlyStoppingStrategy`] uses to decide when training should stop.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum StoppingCondition {
    /// Stop once the tracked metric has gone a given number of epochs without improving.
    NoImprovementSince {
        /// Number of epochs allowed to elapse since the best epoch before training stops.
        n_epochs: usize,
    },
}
14
/// A strategy that decides whether training should be stopped early.
pub trait EarlyStoppingStrategy {
    /// Informs the strategy that `epoch` has finished and returns `true` if training
    /// should stop.
    ///
    /// Implementations may update their internal state (hence `&mut self`) and read
    /// whatever they need for the epoch from `store`.
    fn should_stop(&mut self, epoch: usize, store: &EventStoreClient) -> bool;
}
20
21pub struct MetricEarlyStoppingStrategy {
24 condition: StoppingCondition,
25 metric_name: String,
26 aggregate: Aggregate,
27 direction: Direction,
28 split: Split,
29 best_epoch: usize,
30 best_value: f64,
31}
32
33impl EarlyStoppingStrategy for MetricEarlyStoppingStrategy {
34 fn should_stop(&mut self, epoch: usize, store: &EventStoreClient) -> bool {
35 let current_value =
36 match store.find_metric(&self.metric_name, epoch, self.aggregate, self.split) {
37 Some(value) => value,
38 None => {
39 log::warn!("Can't find metric for early stopping.");
40 return false;
41 }
42 };
43
44 let is_best = match self.direction {
45 Direction::Lowest => current_value < self.best_value,
46 Direction::Highest => current_value > self.best_value,
47 };
48
49 if is_best {
50 log::info!(
51 "New best epoch found {} {}: {}",
52 epoch,
53 self.metric_name,
54 current_value
55 );
56 self.best_value = current_value;
57 self.best_epoch = epoch;
58 return false;
59 }
60
61 match self.condition {
62 StoppingCondition::NoImprovementSince { n_epochs } => {
63 let should_stop = epoch - self.best_epoch >= n_epochs;
64
65 if should_stop {
66 log::info!(
67 "Stopping training loop, no improvement since epoch {}, {}: {}, current \
68 epoch {}, {}: {}",
69 self.best_epoch,
70 self.metric_name,
71 self.best_value,
72 epoch,
73 self.metric_name,
74 current_value
75 );
76 }
77
78 should_stop
79 }
80 }
81 }
82}
83
84impl MetricEarlyStoppingStrategy {
85 pub fn new<Me: Metric>(
92 metric: &Me,
93 aggregate: Aggregate,
94 direction: Direction,
95 split: Split,
96 condition: StoppingCondition,
97 ) -> Self {
98 let init_value = match direction {
99 Direction::Lowest => f64::MAX,
100 Direction::Highest => f64::MIN,
101 };
102
103 Self {
104 metric_name: metric.name(),
105 condition,
106 aggregate,
107 direction,
108 split,
109 best_epoch: 1,
110 best_value: init_value,
111 }
112 }
113}
114
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use crate::{
        TestBackend,
        logger::InMemoryMetricLogger,
        metric::{
            LossMetric,
            processor::{
                Metrics, MinimalEventProcessor,
                test_utils::{end_epoch, process_train},
            },
            store::LogEventStore,
        },
    };

    use super::*;

    #[test]
    fn never_early_stop_while_it_is_improving() {
        test_early_stopping(
            1,
            &[
                (&[0.5, 0.3], false, "Should not stop first epoch"),
                (&[0.4, 0.3], false, "Should not stop when improving"),
                (&[0.3, 0.3], false, "Should not stop when improving"),
                (&[0.2, 0.3], false, "Should not stop when improving"),
            ],
        );
    }

    #[test]
    fn early_stop_when_no_improvement_since_two_epochs() {
        test_early_stopping(
            2,
            &[
                (&[1.0, 0.5], false, "Should not stop first epoch"),
                (&[0.5, 0.3], false, "Should not stop when improving"),
                (
                    &[1.0, 3.0],
                    false,
                    "Should not stop first time it gets worse",
                ),
                (
                    &[1.0, 2.0],
                    true,
                    "Should stop since two following epochs didn't improve",
                ),
            ],
        );
    }

    #[test]
    fn early_stop_when_stays_equal() {
        test_early_stopping(
            2,
            &[
                (&[0.5, 0.3], false, "Should not stop first epoch"),
                (
                    &[0.5, 0.3],
                    false,
                    "Should not stop first time it stays the same",
                ),
                (
                    &[0.5, 0.3],
                    true,
                    "Should stop since two following epochs didn't improve",
                ),
            ],
        );
    }

    /// Drives a [`MetricEarlyStoppingStrategy`] that watches the mean training loss with a
    /// `NoImprovementSince { n_epochs }` condition.
    ///
    /// Each entry of `data` describes one epoch, starting at epoch 1:
    /// - the loss values logged during that epoch;
    /// - whether `should_stop` is expected to return `true` once the epoch ends;
    /// - the message reported if that expectation fails.
    fn test_early_stopping(n_epochs: usize, data: &[(&[f64], bool, &str)]) {
        let loss = LossMetric::<TestBackend>::new();
        let mut early_stopping = MetricEarlyStoppingStrategy::new(
            &loss,
            Aggregate::Mean,
            Direction::Lowest,
            Split::Train,
            StoppingCondition::NoImprovementSince { n_epochs },
        );
        let mut store = LogEventStore::default();
        let mut metrics = Metrics::<f64, f64>::default();

        store.register_logger_train(InMemoryMetricLogger::default());
        metrics.register_train_metric_numeric(loss);

        let store = Arc::new(EventStoreClient::new(store));
        let mut processor = MinimalEventProcessor::new(metrics, store.clone());

        for (epoch, (points, expected_stop, comment)) in (1..).zip(data) {
            for point in points.iter().copied() {
                process_train(&mut processor, point, epoch);
            }
            end_epoch(&mut processor, epoch);

            assert_eq!(
                *expected_stop,
                early_stopping.should_stop(epoch, &store),
                "{comment}"
            );
        }
    }
}