burn_train/learner/
early_stopping.rs

1use crate::metric::{
2    Metric,
3    store::{Aggregate, Direction, EventStoreClient, Split},
4};
5
/// The condition that [early stopping strategies](EarlyStoppingStrategy) should follow.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum StoppingCondition {
    /// Stop when no improvement has happened for the given number of epochs.
    NoImprovementSince {
        /// The number of epochs the metric is allowed to go without improving.
        n_epochs: usize,
    },
}
14
/// A strategy that checks if the training should be stopped.
pub trait EarlyStoppingStrategy {
    /// Updates the strategy's internal state for `epoch` and returns whether the training
    /// should be stopped.
    ///
    /// # Arguments
    ///
    /// * `epoch` - The epoch the decision is made for.
    /// * `store` - The event store client the strategy may query.
    ///
    /// # Returns
    ///
    /// `true` if the training should be stopped, `false` otherwise.
    fn should_stop(&mut self, epoch: usize, store: &EventStoreClient) -> bool;
}
20
/// An [early stopping strategy](EarlyStoppingStrategy) based on a metric collected
/// during training or validation.
pub struct MetricEarlyStoppingStrategy {
    /// Condition evaluated when an epoch does not improve on the best value.
    condition: StoppingCondition,
    /// Name of the monitored metric, used to look it up in the event store.
    metric_name: String,
    /// Aggregation passed to the store when looking up the metric value.
    aggregate: Aggregate,
    /// Whether a lower or a higher metric value counts as an improvement.
    direction: Direction,
    /// Split passed to the store when looking up the metric value.
    split: Split,
    /// Epoch at which `best_value` was recorded; starts at 1.
    best_epoch: usize,
    /// Best metric value observed so far.
    best_value: f64,
}
32
33impl EarlyStoppingStrategy for MetricEarlyStoppingStrategy {
34    fn should_stop(&mut self, epoch: usize, store: &EventStoreClient) -> bool {
35        let current_value =
36            match store.find_metric(&self.metric_name, epoch, self.aggregate, self.split) {
37                Some(value) => value,
38                None => {
39                    log::warn!("Can't find metric for early stopping.");
40                    return false;
41                }
42            };
43
44        let is_best = match self.direction {
45            Direction::Lowest => current_value < self.best_value,
46            Direction::Highest => current_value > self.best_value,
47        };
48
49        if is_best {
50            log::info!(
51                "New best epoch found {} {}: {}",
52                epoch,
53                self.metric_name,
54                current_value
55            );
56            self.best_value = current_value;
57            self.best_epoch = epoch;
58            return false;
59        }
60
61        match self.condition {
62            StoppingCondition::NoImprovementSince { n_epochs } => {
63                let should_stop = epoch - self.best_epoch >= n_epochs;
64
65                if should_stop {
66                    log::info!(
67                        "Stopping training loop, no improvement since epoch {}, {}: {},  current \
68                         epoch {}, {}: {}",
69                        self.best_epoch,
70                        self.metric_name,
71                        self.best_value,
72                        epoch,
73                        self.metric_name,
74                        current_value
75                    );
76                }
77
78                should_stop
79            }
80        }
81    }
82}
83
84impl MetricEarlyStoppingStrategy {
85    /// Create a new [early stopping strategy](EarlyStoppingStrategy) based on a metrics collected
86    /// during training or validation.
87    ///
88    /// # Notes
89    ///
90    /// The metric should be registered for early stopping to work, otherwise no data is collected.
91    pub fn new<Me: Metric>(
92        metric: &Me,
93        aggregate: Aggregate,
94        direction: Direction,
95        split: Split,
96        condition: StoppingCondition,
97    ) -> Self {
98        let init_value = match direction {
99            Direction::Lowest => f64::MAX,
100            Direction::Highest => f64::MIN,
101        };
102
103        Self {
104            metric_name: metric.name(),
105            condition,
106            aggregate,
107            direction,
108            split,
109            best_epoch: 1,
110            best_value: init_value,
111        }
112    }
113}
114
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use crate::{
        TestBackend,
        logger::InMemoryMetricLogger,
        metric::{
            LossMetric,
            processor::{
                Metrics, MinimalEventProcessor,
                test_utils::{end_epoch, process_train},
            },
            store::LogEventStore,
        },
    };

    use super::*;

    #[test]
    fn never_early_stop_while_it_is_improving() {
        test_early_stopping(
            1,
            &[
                (&[0.5, 0.3], false, "Should not stop first epoch"),
                (&[0.4, 0.3], false, "Should not stop when improving"),
                (&[0.3, 0.3], false, "Should not stop when improving"),
                (&[0.2, 0.3], false, "Should not stop when improving"),
            ],
        );
    }

    #[test]
    fn early_stop_when_no_improvement_since_two_epochs() {
        test_early_stopping(
            2,
            &[
                (&[1.0, 0.5], false, "Should not stop first epoch"),
                (&[0.5, 0.3], false, "Should not stop when improving"),
                (
                    &[1.0, 3.0],
                    false,
                    "Should not stop first time it gets worse",
                ),
                (
                    &[1.0, 2.0],
                    true,
                    "Should stop since two following epochs didn't improve",
                ),
            ],
        );
    }

    #[test]
    fn early_stop_when_stays_equal() {
        test_early_stopping(
            2,
            &[
                (&[0.5, 0.3], false, "Should not stop first epoch"),
                (
                    &[0.5, 0.3],
                    false,
                    "Should not stop first time it stays the same",
                ),
                (
                    &[0.5, 0.3],
                    true,
                    "Should stop since two following epochs didn't improve",
                ),
            ],
        );
    }

    /// Runs a loss-based [`MetricEarlyStoppingStrategy`] over a sequence of epochs.
    ///
    /// Each entry of `data` describes one epoch, starting at epoch 1: the loss values
    /// processed during that epoch, the expected result of `should_stop` once the epoch has
    /// ended, and the message reported if the expectation is not met.
    fn test_early_stopping(n_epochs: usize, data: &[(&[f64], bool, &str)]) {
        let loss = LossMetric::<TestBackend>::new();
        let mut early_stopping = MetricEarlyStoppingStrategy::new(
            &loss,
            Aggregate::Mean,
            Direction::Lowest,
            Split::Train,
            StoppingCondition::NoImprovementSince { n_epochs },
        );
        let mut store = LogEventStore::default();
        let mut metrics = Metrics::<f64, f64>::default();

        store.register_logger_train(InMemoryMetricLogger::default());
        metrics.register_train_metric_numeric(loss);

        let store = Arc::new(EventStoreClient::new(store));
        let mut processor = MinimalEventProcessor::new(metrics, Arc::clone(&store));

        // Epochs are numbered from 1.
        for (epoch, (points, expected_stop, comment)) in (1_usize..).zip(data) {
            for point in points.iter() {
                process_train(&mut processor, *point, epoch);
            }
            end_epoch(&mut processor, epoch);

            assert_eq!(
                *expected_stop,
                early_stopping.should_stop(epoch, &store),
                "{comment}"
            );
        }
    }
}