// burn_train/learner/early_stopping.rs

1use crate::metric::{
2    Metric, MetricName,
3    store::{Aggregate, Direction, EventStoreClient, Split},
4};
5
/// The condition that [early stopping strategies](EarlyStoppingStrategy) should follow.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum StoppingCondition {
    /// Stop once the monitored metric has gone a given number of epochs without improving.
    NoImprovementSince {
        /// The number of epochs that may elapse since the best epoch, without any
        /// improvement, before training is stopped.
        n_epochs: usize,
    },
}
15
16/// A strategy that checks if the training should be stopped.
17pub trait EarlyStoppingStrategy: Send {
18    /// Update its current state and returns if the training should be stopped.
19    fn should_stop(&mut self, epoch: usize, store: &EventStoreClient) -> bool;
20}
21
22/// A helper trait to provide type-erased cloning.
23pub trait CloneEarlyStoppingStrategy: EarlyStoppingStrategy + Send {
24    /// Clone into a boxed trait object.
25    fn clone_box(&self) -> Box<dyn CloneEarlyStoppingStrategy>;
26}
27
28/// Blanket-implement `CloneEarlyStoppingStrategy` for any `T` that
29/// already implements your strategy + `Clone` + `Send` + `'static`.
30impl<T> CloneEarlyStoppingStrategy for T
31where
32    T: EarlyStoppingStrategy + Clone + Send + 'static,
33{
34    fn clone_box(&self) -> Box<dyn CloneEarlyStoppingStrategy> {
35        Box::new(self.clone())
36    }
37}
38
39/// Now you can `impl Clone` for the boxed trait object.
40impl Clone for Box<dyn CloneEarlyStoppingStrategy> {
41    fn clone(&self) -> Box<dyn CloneEarlyStoppingStrategy> {
42        self.clone_box()
43    }
44}
45
46/// An [early stopping strategy](EarlyStoppingStrategy) based on a metrics collected
47/// during training or validation.
48#[derive(Clone)]
49pub struct MetricEarlyStoppingStrategy {
50    condition: StoppingCondition,
51    metric_name: MetricName,
52    aggregate: Aggregate,
53    direction: Direction,
54    split: Split,
55    best_epoch: usize,
56    best_value: f64,
57    warmup_epochs: Option<usize>,
58}
59
60impl EarlyStoppingStrategy for MetricEarlyStoppingStrategy {
61    fn should_stop(&mut self, epoch: usize, store: &EventStoreClient) -> bool {
62        let current_value =
63            match store.find_metric(&self.metric_name, epoch, self.aggregate, self.split) {
64                Some(value) => value,
65                None => {
66                    log::warn!("Can't find metric for early stopping.");
67                    return false;
68                }
69            };
70
71        let is_best = match self.direction {
72            Direction::Lowest => current_value < self.best_value,
73            Direction::Highest => current_value > self.best_value,
74        };
75
76        if is_best {
77            log::info!(
78                "New best epoch found {} {}: {}",
79                epoch,
80                self.metric_name,
81                current_value
82            );
83            self.best_value = current_value;
84            self.best_epoch = epoch;
85            return false;
86        }
87
88        if let Some(warmup_epochs) = self.warmup_epochs
89            && epoch <= warmup_epochs
90        {
91            return false;
92        }
93
94        match self.condition {
95            StoppingCondition::NoImprovementSince { n_epochs } => {
96                let should_stop = epoch - self.best_epoch >= n_epochs;
97
98                if should_stop {
99                    log::info!(
100                        "Stopping training loop, no improvement since epoch {}, {}: {},  current \
101                         epoch {}, {}: {}",
102                        self.best_epoch,
103                        self.metric_name,
104                        self.best_value,
105                        epoch,
106                        self.metric_name,
107                        current_value
108                    );
109                }
110
111                should_stop
112            }
113        }
114    }
115}
116
117impl MetricEarlyStoppingStrategy {
118    /// Create a new [early stopping strategy](EarlyStoppingStrategy) based on a metrics collected
119    /// during training or validation.
120    ///
121    /// # Notes
122    ///
123    /// The metric should be registered for early stopping to work, otherwise no data is collected.
124    pub fn new<Me: Metric>(
125        metric: &Me,
126        aggregate: Aggregate,
127        direction: Direction,
128        split: Split,
129        condition: StoppingCondition,
130    ) -> Self {
131        let init_value = match direction {
132            Direction::Lowest => f64::MAX,
133            Direction::Highest => f64::MIN,
134        };
135
136        Self {
137            metric_name: metric.name(),
138            condition,
139            aggregate,
140            direction,
141            split,
142            best_epoch: 1,
143            best_value: init_value,
144            warmup_epochs: None,
145        }
146    }
147
148    /// Get the warmup period.
149    ///
150    /// Early stopping will not trigger during the warmup epochs.
151    pub fn warmup_epochs(&self) -> Option<usize> {
152        self.warmup_epochs
153    }
154
155    /// Set the warmup epochs.
156    ///
157    /// Early stopping will not trigger during the warmup epochs.
158    ///
159    /// # Arguments
160    /// - `warmup`: the number of warmup epochs, or None.
161    pub fn with_warmup_epochs(self, warmup: Option<usize>) -> Self {
162        Self {
163            warmup_epochs: warmup,
164            ..self
165        }
166    }
167}
168
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use crate::{
        EventProcessorTraining, TestBackend,
        logger::InMemoryMetricLogger,
        metric::{
            LossMetric,
            processor::{
                MetricsTraining, MinimalEventProcessor,
                test_utils::{end_epoch, process_train},
            },
            store::LogEventStore,
        },
    };

    use super::*;

    // In every scenario below, each tuple describes one epoch: the values fed to
    // `process_train` during that epoch, the expected `should_stop` verdict once the epoch
    // has ended, and the message reported if the verdict differs.

    #[test]
    fn never_early_stop_while_it_is_improving() {
        test_early_stopping(
            None,
            1,
            &[
                (&[0.5, 0.3], false, "Should not stop first epoch"),
                (&[0.4, 0.3], false, "Should not stop when improving"),
                (&[0.3, 0.3], false, "Should not stop when improving"),
                (&[0.2, 0.3], false, "Should not stop when improving"),
            ],
        );
    }

    #[test]
    fn early_stop_when_no_improvement_since_two_epochs() {
        test_early_stopping(
            None,
            2,
            &[
                (&[1.0, 0.5], false, "Should not stop first epoch"),
                (&[0.5, 0.3], false, "Should not stop when improving"),
                (
                    &[1.0, 3.0],
                    false,
                    "Should not stop first time it gets worse",
                ),
                (
                    &[1.0, 2.0],
                    true,
                    "Should stop since two following epochs didn't improve",
                ),
            ],
        );
    }

    #[test]
    fn early_stopping_with_warmup() {
        test_early_stopping(
            Some(3),
            2,
            &[
                (&[1.0, 0.5], false, "Should not stop during warmup"),
                (&[1.0, 0.5], false, "Should not stop during warmup"),
                (&[1.0, 0.5], false, "Should not stop during warmup"),
                (
                    &[1.0, 0.5],
                    true,
                    "Should stop when not improving after warmup",
                ),
            ],
        );
    }

    #[test]
    fn early_stop_when_stays_equal() {
        test_early_stopping(
            None,
            2,
            &[
                (&[0.5, 0.3], false, "Should not stop first epoch"),
                (
                    &[0.5, 0.3],
                    false,
                    "Should not stop first time it stays the same",
                ),
                (
                    &[0.5, 0.3],
                    true,
                    "Should stop since two following epochs didn't improve",
                ),
            ],
        );
    }

    /// Runs a [`MetricEarlyStoppingStrategy`] over a scripted sequence of epochs.
    ///
    /// The strategy monitors the train loss with `Aggregate::Mean` and `Direction::Lowest`,
    /// stops after `n_epochs` epochs without improvement, and uses `warmup` as its warmup
    /// period. Epochs are numbered from 1.
    ///
    /// # Arguments
    /// - `warmup`: the warmup period given to the strategy.
    /// - `n_epochs`: the `NoImprovementSince` threshold.
    /// - `data`: one `(values, expected verdict, failure message)` tuple per epoch.
    fn test_early_stopping(warmup: Option<usize>, n_epochs: usize, data: &[(&[f64], bool, &str)]) {
        let loss = LossMetric::<TestBackend>::new();
        let mut early_stopping = MetricEarlyStoppingStrategy::new(
            &loss,
            Aggregate::Mean,
            Direction::Lowest,
            Split::Train,
            StoppingCondition::NoImprovementSince { n_epochs },
        )
        .with_warmup_epochs(warmup);
        let mut store = LogEventStore::default();
        let mut metrics = MetricsTraining::<f64, f64>::default();

        store.register_logger(InMemoryMetricLogger::default());
        metrics.register_train_metric_numeric(loss);

        let store = Arc::new(EventStoreClient::new(store));
        let mut processor = MinimalEventProcessor::new(metrics, store.clone());

        processor.process_train(crate::LearnerEvent::Start);
        for (epoch, (points, expected_stop, comment)) in (1..).zip(data) {
            for point in points.iter() {
                process_train(&mut processor, *point, epoch);
            }
            end_epoch(&mut processor, epoch);

            assert_eq!(
                *expected_stop,
                early_stopping.should_stop(epoch, &store),
                "{comment}"
            );
        }
    }
}
298}