burn_train/learner/
early_stopping.rs

1use crate::metric::{
2    Metric, MetricName,
3    store::{Aggregate, Direction, EventStoreClient, Split},
4};
5
/// The condition that [early stopping strategies](EarlyStoppingStrategy) should follow.
///
/// Derives `Debug`, `PartialEq` and `Eq` in addition to `Clone` so that a condition can be
/// logged and compared in configuration code and tests.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum StoppingCondition {
    /// When no improvement has happened since the given number of epochs.
    NoImprovementSince {
        /// The number of epochs that may elapse after the best epoch without any
        /// improvement; training is stopped once that many epochs have passed.
        n_epochs: usize,
    },
}
15
16/// A strategy that checks if the training should be stopped.
17pub trait EarlyStoppingStrategy: Send {
18    /// Update its current state and returns if the training should be stopped.
19    fn should_stop(&mut self, epoch: usize, store: &EventStoreClient) -> bool;
20}
21
22/// A helper trait to provide type-erased cloning.
23pub trait CloneEarlyStoppingStrategy: EarlyStoppingStrategy + Send {
24    /// Clone into a boxed trait object.
25    fn clone_box(&self) -> Box<dyn CloneEarlyStoppingStrategy>;
26}
27
28/// Blanket-implement `CloneEarlyStoppingStrategy` for any `T` that
29/// already implements your strategy + `Clone` + `Send` + `'static`.
30impl<T> CloneEarlyStoppingStrategy for T
31where
32    T: EarlyStoppingStrategy + Clone + Send + 'static,
33{
34    fn clone_box(&self) -> Box<dyn CloneEarlyStoppingStrategy> {
35        Box::new(self.clone())
36    }
37}
38
39/// Now you can `impl Clone` for the boxed trait object.
40impl Clone for Box<dyn CloneEarlyStoppingStrategy> {
41    fn clone(&self) -> Box<dyn CloneEarlyStoppingStrategy> {
42        self.clone_box()
43    }
44}
45
46/// An [early stopping strategy](EarlyStoppingStrategy) based on a metrics collected
47/// during training or validation.
48#[derive(Clone)]
49pub struct MetricEarlyStoppingStrategy {
50    condition: StoppingCondition,
51    metric_name: MetricName,
52    aggregate: Aggregate,
53    direction: Direction,
54    split: Split,
55    best_epoch: usize,
56    best_value: f64,
57}
58
59impl EarlyStoppingStrategy for MetricEarlyStoppingStrategy {
60    fn should_stop(&mut self, epoch: usize, store: &EventStoreClient) -> bool {
61        let current_value =
62            match store.find_metric(&self.metric_name, epoch, self.aggregate, self.split) {
63                Some(value) => value,
64                None => {
65                    log::warn!("Can't find metric for early stopping.");
66                    return false;
67                }
68            };
69
70        let is_best = match self.direction {
71            Direction::Lowest => current_value < self.best_value,
72            Direction::Highest => current_value > self.best_value,
73        };
74
75        if is_best {
76            log::info!(
77                "New best epoch found {} {}: {}",
78                epoch,
79                self.metric_name,
80                current_value
81            );
82            self.best_value = current_value;
83            self.best_epoch = epoch;
84            return false;
85        }
86
87        match self.condition {
88            StoppingCondition::NoImprovementSince { n_epochs } => {
89                let should_stop = epoch - self.best_epoch >= n_epochs;
90
91                if should_stop {
92                    log::info!(
93                        "Stopping training loop, no improvement since epoch {}, {}: {},  current \
94                         epoch {}, {}: {}",
95                        self.best_epoch,
96                        self.metric_name,
97                        self.best_value,
98                        epoch,
99                        self.metric_name,
100                        current_value
101                    );
102                }
103
104                should_stop
105            }
106        }
107    }
108}
109
110impl MetricEarlyStoppingStrategy {
111    /// Create a new [early stopping strategy](EarlyStoppingStrategy) based on a metrics collected
112    /// during training or validation.
113    ///
114    /// # Notes
115    ///
116    /// The metric should be registered for early stopping to work, otherwise no data is collected.
117    pub fn new<Me: Metric>(
118        metric: &Me,
119        aggregate: Aggregate,
120        direction: Direction,
121        split: Split,
122        condition: StoppingCondition,
123    ) -> Self {
124        let init_value = match direction {
125            Direction::Lowest => f64::MAX,
126            Direction::Highest => f64::MIN,
127        };
128
129        Self {
130            metric_name: metric.name(),
131            condition,
132            aggregate,
133            direction,
134            split,
135            best_epoch: 1,
136            best_value: init_value,
137        }
138    }
139}
140
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use crate::{
        TestBackend,
        logger::InMemoryMetricLogger,
        metric::{
            LossMetric,
            processor::{
                MetricsTraining, MinimalEventProcessor,
                test_utils::{end_epoch, process_train},
            },
            store::LogEventStore,
        },
    };

    use super::*;

    #[test]
    fn never_early_stop_while_it_is_improving() {
        test_early_stopping(
            1,
            &[
                (&[0.5, 0.3], false, "Should not stop first epoch"),
                (&[0.4, 0.3], false, "Should not stop when improving"),
                (&[0.3, 0.3], false, "Should not stop when improving"),
                (&[0.2, 0.3], false, "Should not stop when improving"),
            ],
        );
    }

    #[test]
    fn early_stop_when_no_improvement_since_two_epochs() {
        test_early_stopping(
            2,
            &[
                (&[1.0, 0.5], false, "Should not stop first epoch"),
                (&[0.5, 0.3], false, "Should not stop when improving"),
                (
                    &[1.0, 3.0],
                    false,
                    "Should not stop first time it gets worse",
                ),
                (
                    &[1.0, 2.0],
                    true,
                    "Should stop since two following epochs didn't improve",
                ),
            ],
        );
    }

    #[test]
    fn early_stop_when_stays_equal() {
        test_early_stopping(
            2,
            &[
                (&[0.5, 0.3], false, "Should not stop first epoch"),
                (
                    &[0.5, 0.3],
                    false,
                    "Should not stop first time it stars the same",
                ),
                (
                    &[0.5, 0.3],
                    true,
                    "Should stop since two following epochs didn't improve",
                ),
            ],
        );
    }

    /// Drives a loss-based strategy (mean aggregate, lowest is best, train split) through one
    /// epoch per entry of `data`.
    ///
    /// Each entry holds the loss points fed during the epoch, the expected result of
    /// `should_stop` at the end of that epoch, and the message shown when the expectation
    /// fails. Epochs are numbered from 1.
    fn test_early_stopping(n_epochs: usize, data: &[(&[f64], bool, &str)]) {
        let loss = LossMetric::<TestBackend>::new();
        let mut early_stopping = MetricEarlyStoppingStrategy::new(
            &loss,
            Aggregate::Mean,
            Direction::Lowest,
            Split::Train,
            StoppingCondition::NoImprovementSince { n_epochs },
        );

        let mut store = LogEventStore::default();
        let mut metrics = MetricsTraining::<f64, f64>::default();
        store.register_logger_train(InMemoryMetricLogger::default());
        metrics.register_train_metric_numeric(loss);

        let store = Arc::new(EventStoreClient::new(store));
        let mut processor = MinimalEventProcessor::new(metrics, store.clone());

        for (epoch, (points, expected_stop, comment)) in (1..).zip(data) {
            for &point in points.iter() {
                process_train(&mut processor, point, epoch);
            }
            end_epoch(&mut processor, epoch);

            assert_eq!(
                *expected_stop,
                early_stopping.should_stop(epoch, &store),
                "{comment}"
            );
        }
    }
}