1use crate::metric::{
2 Metric, MetricName,
3 store::{Aggregate, Direction, EventStoreClient, Split},
4};
5
/// Rule used by an early-stopping strategy to decide when training should end.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum StoppingCondition {
    /// Stop once the monitored value has not improved for a number of epochs.
    NoImprovementSince {
        /// Number of epochs since the best epoch after which training stops.
        n_epochs: usize,
    },
}
15
16pub trait EarlyStoppingStrategy: Send {
18 fn should_stop(&mut self, epoch: usize, store: &EventStoreClient) -> bool;
20}
21
22pub trait CloneEarlyStoppingStrategy: EarlyStoppingStrategy + Send {
24 fn clone_box(&self) -> Box<dyn CloneEarlyStoppingStrategy>;
26}
27
28impl<T> CloneEarlyStoppingStrategy for T
31where
32 T: EarlyStoppingStrategy + Clone + Send + 'static,
33{
34 fn clone_box(&self) -> Box<dyn CloneEarlyStoppingStrategy> {
35 Box::new(self.clone())
36 }
37}
38
39impl Clone for Box<dyn CloneEarlyStoppingStrategy> {
41 fn clone(&self) -> Box<dyn CloneEarlyStoppingStrategy> {
42 self.clone_box()
43 }
44}
45
46#[derive(Clone)]
49pub struct MetricEarlyStoppingStrategy {
50 condition: StoppingCondition,
51 metric_name: MetricName,
52 aggregate: Aggregate,
53 direction: Direction,
54 split: Split,
55 best_epoch: usize,
56 best_value: f64,
57}
58
59impl EarlyStoppingStrategy for MetricEarlyStoppingStrategy {
60 fn should_stop(&mut self, epoch: usize, store: &EventStoreClient) -> bool {
61 let current_value =
62 match store.find_metric(&self.metric_name, epoch, self.aggregate, self.split) {
63 Some(value) => value,
64 None => {
65 log::warn!("Can't find metric for early stopping.");
66 return false;
67 }
68 };
69
70 let is_best = match self.direction {
71 Direction::Lowest => current_value < self.best_value,
72 Direction::Highest => current_value > self.best_value,
73 };
74
75 if is_best {
76 log::info!(
77 "New best epoch found {} {}: {}",
78 epoch,
79 self.metric_name,
80 current_value
81 );
82 self.best_value = current_value;
83 self.best_epoch = epoch;
84 return false;
85 }
86
87 match self.condition {
88 StoppingCondition::NoImprovementSince { n_epochs } => {
89 let should_stop = epoch - self.best_epoch >= n_epochs;
90
91 if should_stop {
92 log::info!(
93 "Stopping training loop, no improvement since epoch {}, {}: {}, current \
94 epoch {}, {}: {}",
95 self.best_epoch,
96 self.metric_name,
97 self.best_value,
98 epoch,
99 self.metric_name,
100 current_value
101 );
102 }
103
104 should_stop
105 }
106 }
107 }
108}
109
110impl MetricEarlyStoppingStrategy {
111 pub fn new<Me: Metric>(
118 metric: &Me,
119 aggregate: Aggregate,
120 direction: Direction,
121 split: Split,
122 condition: StoppingCondition,
123 ) -> Self {
124 let init_value = match direction {
125 Direction::Lowest => f64::MAX,
126 Direction::Highest => f64::MIN,
127 };
128
129 Self {
130 metric_name: metric.name(),
131 condition,
132 aggregate,
133 direction,
134 split,
135 best_epoch: 1,
136 best_value: init_value,
137 }
138 }
139}
140
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use crate::{
        TestBackend,
        logger::InMemoryMetricLogger,
        metric::{
            LossMetric,
            processor::{
                MetricsTraining, MinimalEventProcessor,
                test_utils::{end_epoch, process_train},
            },
            store::LogEventStore,
        },
    };

    use super::*;

    #[test]
    fn never_early_stop_while_it_is_improving() {
        test_early_stopping(
            1,
            &[
                (&[0.5, 0.3], false, "Should not stop first epoch"),
                (&[0.4, 0.3], false, "Should not stop when improving"),
                (&[0.3, 0.3], false, "Should not stop when improving"),
                (&[0.2, 0.3], false, "Should not stop when improving"),
            ],
        );
    }

    #[test]
    fn early_stop_when_no_improvement_since_two_epochs() {
        test_early_stopping(
            2,
            &[
                (&[1.0, 0.5], false, "Should not stop first epoch"),
                (&[0.5, 0.3], false, "Should not stop when improving"),
                (
                    &[1.0, 3.0],
                    false,
                    "Should not stop first time it gets worse",
                ),
                (
                    &[1.0, 2.0],
                    true,
                    "Should stop since two following epochs didn't improve",
                ),
            ],
        );
    }

    #[test]
    fn early_stop_when_stays_equal() {
        test_early_stopping(
            2,
            &[
                (&[0.5, 0.3], false, "Should not stop first epoch"),
                (
                    &[0.5, 0.3],
                    false,
                    "Should not stop first time it stays the same",
                ),
                (
                    &[0.5, 0.3],
                    true,
                    "Should stop since two following epochs didn't improve",
                ),
            ],
        );
    }

    /// Feeds one epoch per entry of `data` into a loss-metric early-stopping strategy.
    ///
    /// Each entry is `(train loss points, expected should_stop result, assertion message)`.
    /// Epochs are numbered from 1, and the strategy uses `NoImprovementSince { n_epochs }`.
    fn test_early_stopping(n_epochs: usize, data: &[(&[f64], bool, &str)]) {
        let loss = LossMetric::<TestBackend>::new();
        let mut early_stopping = MetricEarlyStoppingStrategy::new(
            &loss,
            Aggregate::Mean,
            Direction::Lowest,
            Split::Train,
            StoppingCondition::NoImprovementSince { n_epochs },
        );
        let mut store = LogEventStore::default();
        let mut metrics = MetricsTraining::<f64, f64>::default();

        store.register_logger_train(InMemoryMetricLogger::default());
        metrics.register_train_metric_numeric(loss);

        let store = Arc::new(EventStoreClient::new(store));
        let mut processor = MinimalEventProcessor::new(metrics, store.clone());

        for (epoch, (points, expected_stop, comment)) in (1..).zip(data) {
            for &point in points.iter() {
                process_train(&mut processor, point, epoch);
            }
            end_epoch(&mut processor, epoch);

            assert_eq!(
                *expected_stop,
                early_stopping.should_stop(epoch, &store),
                "{comment}"
            );
        }
    }
}