burn_train/learner/
early_stopping.rs

1use crate::metric::{
2    store::{Aggregate, Direction, EventStoreClient, Split},
3    Metric,
4};
5
/// The condition that [early stopping strategies](EarlyStoppingStrategy) should follow.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum StoppingCondition {
    /// Stop when no improvement has happened for the given number of epochs.
    NoImprovementSince {
        /// How many epochs may elapse since the best epoch without a new best value being
        /// recorded. Once that many epochs have passed, the condition is met.
        n_epochs: usize,
    },
}
14
15/// A strategy that checks if the training should be stopped.
16pub trait EarlyStoppingStrategy {
17    /// Update its current state and returns if the training should be stopped.
18    fn should_stop(&mut self, epoch: usize, store: &EventStoreClient) -> bool;
19}
20
21/// An [early stopping strategy](EarlyStoppingStrategy) based on a metrics collected
22/// during training or validation.
23pub struct MetricEarlyStoppingStrategy {
24    condition: StoppingCondition,
25    metric_name: String,
26    aggregate: Aggregate,
27    direction: Direction,
28    split: Split,
29    best_epoch: usize,
30    best_value: f64,
31}
32
33impl EarlyStoppingStrategy for MetricEarlyStoppingStrategy {
34    fn should_stop(&mut self, epoch: usize, store: &EventStoreClient) -> bool {
35        let current_value =
36            match store.find_metric(&self.metric_name, epoch, self.aggregate, self.split) {
37                Some(value) => value,
38                None => {
39                    log::warn!("Can't find metric for early stopping.");
40                    return false;
41                }
42            };
43
44        let is_best = match self.direction {
45            Direction::Lowest => current_value < self.best_value,
46            Direction::Highest => current_value > self.best_value,
47        };
48
49        if is_best {
50            log::info!(
51                "New best epoch found {} {}: {}",
52                epoch,
53                self.metric_name,
54                current_value
55            );
56            self.best_value = current_value;
57            self.best_epoch = epoch;
58            return false;
59        }
60
61        match self.condition {
62            StoppingCondition::NoImprovementSince { n_epochs } => {
63                let should_stop = epoch - self.best_epoch >= n_epochs;
64
65                if should_stop {
66                    log::info!(
67                        "Stopping training loop, no improvement since epoch {}, {}: {},  current \
68                         epoch {}, {}: {}",
69                        self.best_epoch,
70                        self.metric_name,
71                        self.best_value,
72                        epoch,
73                        self.metric_name,
74                        current_value
75                    );
76                }
77
78                should_stop
79            }
80        }
81    }
82}
83
84impl MetricEarlyStoppingStrategy {
85    /// Create a new [early stopping strategy](EarlyStoppingStrategy) based on a metrics collected
86    /// during training or validation.
87    ///
88    /// # Notes
89    ///
90    /// The metric should be registered for early stopping to work, otherwise no data is collected.
91    pub fn new<Me: Metric>(
92        aggregate: Aggregate,
93        direction: Direction,
94        split: Split,
95        condition: StoppingCondition,
96    ) -> Self {
97        let init_value = match direction {
98            Direction::Lowest => f64::MAX,
99            Direction::Highest => f64::MIN,
100        };
101
102        Self {
103            metric_name: Me::NAME.to_string(),
104            condition,
105            aggregate,
106            direction,
107            split,
108            best_epoch: 1,
109            best_value: init_value,
110        }
111    }
112}
113
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use crate::{
        logger::InMemoryMetricLogger,
        metric::{
            processor::{
                test_utils::{end_epoch, process_train},
                Metrics, MinimalEventProcessor,
            },
            store::LogEventStore,
            LossMetric,
        },
        TestBackend,
    };

    use super::*;

    /// One simulated epoch: the loss points to process, whether the strategy is expected to
    /// stop after it, and the message reported when that expectation fails.
    type EpochCase<'a> = (&'a [f64], bool, &'a str);

    #[test]
    fn never_early_stop_while_it_is_improving() {
        let epochs: &[EpochCase] = &[
            (&[0.5, 0.3], false, "Should not stop first epoch"),
            (&[0.4, 0.3], false, "Should not stop when improving"),
            (&[0.3, 0.3], false, "Should not stop when improving"),
            (&[0.2, 0.3], false, "Should not stop when improving"),
        ];

        test_early_stopping(1, epochs);
    }

    #[test]
    fn early_stop_when_no_improvement_since_two_epochs() {
        let epochs: &[EpochCase] = &[
            (&[1.0, 0.5], false, "Should not stop first epoch"),
            (&[0.5, 0.3], false, "Should not stop when improving"),
            (
                &[1.0, 3.0],
                false,
                "Should not stop first time it gets worse",
            ),
            (
                &[1.0, 2.0],
                true,
                "Should stop since two following epochs didn't improve",
            ),
        ];

        test_early_stopping(2, epochs);
    }

    #[test]
    fn early_stop_when_stays_equal() {
        let epochs: &[EpochCase] = &[
            (&[0.5, 0.3], false, "Should not stop first epoch"),
            (
                &[0.5, 0.3],
                false,
                "Should not stop first time it stars the same",
            ),
            (
                &[0.5, 0.3],
                true,
                "Should stop since two following epochs didn't improve",
            ),
        ];

        test_early_stopping(2, epochs);
    }

    /// Feeds each epoch's loss points through a train-only event processor, then checks the
    /// decision of a lowest-mean-loss strategy configured with `n_epochs` against the expected
    /// flag of that epoch. Epochs are numbered from 1.
    fn test_early_stopping(n_epochs: usize, data: &[(&[f64], bool, &str)]) {
        let mut early_stopping = MetricEarlyStoppingStrategy::new::<LossMetric<TestBackend>>(
            Aggregate::Mean,
            Direction::Lowest,
            Split::Train,
            StoppingCondition::NoImprovementSince { n_epochs },
        );

        let mut store = LogEventStore::default();
        store.register_logger_train(InMemoryMetricLogger::default());

        let mut metrics = Metrics::<f64, f64>::default();
        metrics.register_train_metric_numeric(LossMetric::<TestBackend>::new());

        let store = Arc::new(EventStoreClient::new(store));
        let mut processor = MinimalEventProcessor::new(metrics, store.clone());

        for (epoch, (losses, expected_stop, comment)) in (1_usize..).zip(data) {
            for loss in losses.iter() {
                process_train(&mut processor, *loss, epoch);
            }
            end_epoch(&mut processor, epoch);

            let stopped = early_stopping.should_stop(epoch, &store);
            assert_eq!(*expected_stop, stopped, "{comment}");
        }
    }
}