1use crate::metric::{
2 store::{Aggregate, Direction, EventStoreClient, Split},
3 Metric,
4};
5
/// The condition an [early stopping strategy](EarlyStoppingStrategy) evaluates to decide whether
/// training should stop.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum StoppingCondition {
    /// Stop once `n_epochs` or more epochs have elapsed since the epoch that produced the best
    /// monitored value.
    NoImprovementSince {
        /// Number of epochs without improvement tolerated before training stops.
        n_epochs: usize,
    },
}
14
15pub trait EarlyStoppingStrategy {
17 fn should_stop(&mut self, epoch: usize, store: &EventStoreClient) -> bool;
19}
20
21pub struct MetricEarlyStoppingStrategy {
24 condition: StoppingCondition,
25 metric_name: String,
26 aggregate: Aggregate,
27 direction: Direction,
28 split: Split,
29 best_epoch: usize,
30 best_value: f64,
31}
32
33impl EarlyStoppingStrategy for MetricEarlyStoppingStrategy {
34 fn should_stop(&mut self, epoch: usize, store: &EventStoreClient) -> bool {
35 let current_value =
36 match store.find_metric(&self.metric_name, epoch, self.aggregate, self.split) {
37 Some(value) => value,
38 None => {
39 log::warn!("Can't find metric for early stopping.");
40 return false;
41 }
42 };
43
44 let is_best = match self.direction {
45 Direction::Lowest => current_value < self.best_value,
46 Direction::Highest => current_value > self.best_value,
47 };
48
49 if is_best {
50 log::info!(
51 "New best epoch found {} {}: {}",
52 epoch,
53 self.metric_name,
54 current_value
55 );
56 self.best_value = current_value;
57 self.best_epoch = epoch;
58 return false;
59 }
60
61 match self.condition {
62 StoppingCondition::NoImprovementSince { n_epochs } => {
63 let should_stop = epoch - self.best_epoch >= n_epochs;
64
65 if should_stop {
66 log::info!(
67 "Stopping training loop, no improvement since epoch {}, {}: {}, current \
68 epoch {}, {}: {}",
69 self.best_epoch,
70 self.metric_name,
71 self.best_value,
72 epoch,
73 self.metric_name,
74 current_value
75 );
76 }
77
78 should_stop
79 }
80 }
81 }
82}
83
84impl MetricEarlyStoppingStrategy {
85 pub fn new<Me: Metric>(
92 aggregate: Aggregate,
93 direction: Direction,
94 split: Split,
95 condition: StoppingCondition,
96 ) -> Self {
97 let init_value = match direction {
98 Direction::Lowest => f64::MAX,
99 Direction::Highest => f64::MIN,
100 };
101
102 Self {
103 metric_name: Me::NAME.to_string(),
104 condition,
105 aggregate,
106 direction,
107 split,
108 best_epoch: 1,
109 best_value: init_value,
110 }
111 }
112}
113
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use crate::{
        logger::InMemoryMetricLogger,
        metric::{
            processor::{
                test_utils::{end_epoch, process_train},
                Metrics, MinimalEventProcessor,
            },
            store::LogEventStore,
            LossMetric,
        },
        TestBackend,
    };

    use super::*;

    #[test]
    fn never_early_stop_while_it_is_improving() {
        test_early_stopping(
            1,
            &[
                (&[0.5, 0.3], false, "Should not stop first epoch"),
                (&[0.4, 0.3], false, "Should not stop when improving"),
                (&[0.3, 0.3], false, "Should not stop when improving"),
                (&[0.2, 0.3], false, "Should not stop when improving"),
            ],
        );
    }

    #[test]
    fn early_stop_when_no_improvement_since_two_epochs() {
        test_early_stopping(
            2,
            &[
                (&[1.0, 0.5], false, "Should not stop first epoch"),
                (&[0.5, 0.3], false, "Should not stop when improving"),
                (
                    &[1.0, 3.0],
                    false,
                    "Should not stop first time it gets worse",
                ),
                (
                    &[1.0, 2.0],
                    true,
                    "Should stop since two following epochs didn't improve",
                ),
            ],
        );
    }

    #[test]
    fn early_stop_when_stays_equal() {
        test_early_stopping(
            2,
            &[
                (&[0.5, 0.3], false, "Should not stop first epoch"),
                (
                    &[0.5, 0.3],
                    false,
                    "Should not stop first time it stays the same",
                ),
                (
                    &[0.5, 0.3],
                    true,
                    "Should stop since two following epochs didn't improve",
                ),
            ],
        );
    }

    #[test]
    fn early_stop_when_highest_metric_stops_increasing() {
        // Each epoch logs identical points, so the expectations hold regardless of how the
        // per-epoch values are combined.
        test_early_stopping_with_direction(
            Direction::Highest,
            2,
            &[
                (&[0.3, 0.3], false, "Should not stop first epoch"),
                (&[0.5, 0.5], false, "Should not stop when improving"),
                (
                    &[0.4, 0.4],
                    false,
                    "Should not stop first time it gets worse",
                ),
                (
                    &[0.2, 0.2],
                    true,
                    "Should stop since two following epochs didn't improve",
                ),
            ],
        );
    }

    /// Runs a scenario where lower loss values are better.
    fn test_early_stopping(n_epochs: usize, data: &[(&[f64], bool, &str)]) {
        test_early_stopping_with_direction(Direction::Lowest, n_epochs, data);
    }

    /// Feeds each entry of `data` as one training epoch (numbered from 1) and checks the
    /// strategy's decision after that epoch.
    ///
    /// Each entry is `(loss points logged during the epoch, expected stop decision, message)`.
    fn test_early_stopping_with_direction(
        direction: Direction,
        n_epochs: usize,
        data: &[(&[f64], bool, &str)],
    ) {
        let mut early_stopping = MetricEarlyStoppingStrategy::new::<LossMetric<TestBackend>>(
            Aggregate::Mean,
            direction,
            Split::Train,
            StoppingCondition::NoImprovementSince { n_epochs },
        );
        let mut store = LogEventStore::default();
        let mut metrics = Metrics::<f64, f64>::default();

        store.register_logger_train(InMemoryMetricLogger::default());
        metrics.register_train_metric_numeric(LossMetric::<TestBackend>::new());

        let store = Arc::new(EventStoreClient::new(store));
        let mut processor = MinimalEventProcessor::new(metrics, store.clone());

        for (epoch, (points, expected_stop, comment)) in (1..).zip(data) {
            for point in points.iter() {
                process_train(&mut processor, *point, epoch);
            }
            end_epoch(&mut processor, epoch);

            assert_eq!(
                *expected_stop,
                early_stopping.should_stop(epoch, &store),
                "{comment}"
            );
        }
    }
}