Skip to main content

entrenar/train/callback/
early_stopping.rs

//! Early stopping callback to halt training when loss plateaus
2
3use super::traits::{CallbackAction, CallbackContext, TrainerCallback};
4
/// Early stopping callback to halt training when loss plateaus
///
/// Monitors a loss value once per epoch and requests that training stop
/// when no sufficient improvement has been seen for `patience` epochs.
/// An epoch counts as an improvement only when its loss is lower than the
/// best loss seen so far by more than `min_delta`.
///
/// # Example
///
/// ```rust
/// use entrenar::train::callback::EarlyStopping;
///
/// // Stop if no improvement for 5 epochs, min improvement 0.001
/// let early_stop = EarlyStopping::new(5, 0.001);
/// ```
#[derive(Clone, Debug)]
pub struct EarlyStopping {
    /// Number of epochs to wait for improvement
    patience: usize,
    /// Minimum decrease of the loss that counts as an improvement
    min_delta: f32,
    /// Best (lowest) loss seen so far; `f32::INFINITY` until the first epoch
    best_loss: f32,
    /// Consecutive epochs without improvement
    pub(crate) epochs_without_improvement: usize,
    /// Whether to restore best weights (placeholder: this type only stores the flag)
    pub(crate) restore_best: bool,
    /// Monitor validation loss instead of training loss
    monitor_val: bool,
}

impl EarlyStopping {
    /// Create a new early stopping callback.
    ///
    /// * `patience` - number of consecutive non-improving epochs to tolerate.
    /// * `min_delta` - how far below the best loss a new loss must fall to
    ///   count as an improvement.
    ///
    /// The callback starts with no best loss recorded (`f32::INFINITY`),
    /// best-weight restoration disabled, and training loss as the monitored
    /// metric.
    #[must_use]
    pub fn new(patience: usize, min_delta: f32) -> Self {
        Self {
            patience,
            min_delta,
            best_loss: f32::INFINITY,
            epochs_without_improvement: 0,
            restore_best: false,
            monitor_val: false,
        }
    }

    /// Configure to restore best weights on stop.
    ///
    /// Consumes and returns the callback so calls can be chained.
    #[must_use]
    pub fn with_restore_best(mut self) -> Self {
        self.restore_best = true;
        self
    }

    /// Configure to monitor validation loss (requires validation data).
    ///
    /// When enabled, the validation loss is preferred over the training loss.
    /// If an epoch reports no validation loss, the training loss is used as
    /// a fallback for that epoch.
    ///
    /// Consumes and returns the callback so calls can be chained.
    #[must_use]
    pub fn monitor_validation(mut self) -> Self {
        self.monitor_val = true;
        self
    }

    /// Reset internal state.
    ///
    /// Forgets the best loss and clears the non-improving epoch counter.
    /// The configuration (`patience`, `min_delta`, flags) is left untouched.
    pub fn reset(&mut self) {
        self.best_loss = f32::INFINITY;
        self.epochs_without_improvement = 0;
    }

    /// Record `loss` for the current epoch and report whether it improved.
    ///
    /// On improvement the best loss is updated and the counter is cleared;
    /// otherwise the counter is incremented. A NaN loss never compares as
    /// lower, so it is always treated as "no improvement".
    fn check_improvement(&mut self, loss: f32) -> bool {
        let improved = loss < self.best_loss - self.min_delta;
        if improved {
            self.best_loss = loss;
            self.epochs_without_improvement = 0;
        } else {
            self.epochs_without_improvement += 1;
        }
        improved
    }
}
80
81impl TrainerCallback for EarlyStopping {
82    fn on_epoch_end(&mut self, ctx: &CallbackContext) -> CallbackAction {
83        // Use val_loss if monitoring validation (with fallback), otherwise use training loss
84        let loss = if self.monitor_val { ctx.val_loss.unwrap_or(ctx.loss) } else { ctx.loss };
85        self.check_improvement(loss);
86
87        if self.epochs_without_improvement >= self.patience {
88            eprintln!(
89                "Early stopping: no improvement for {} epochs (best loss: {:.4})",
90                self.patience, self.best_loss
91            );
92            CallbackAction::Stop
93        } else {
94            CallbackAction::Continue
95        }
96    }
97
98    fn name(&self) -> &'static str {
99        "EarlyStopping"
100    }
101}
102
#[cfg(test)]
mod tests {
    use super::*;

    /// Training stops only after `patience` consecutive non-improving epochs.
    #[test]
    fn test_early_stopping_patience() {
        let mut es = EarlyStopping::new(3, 0.001);
        let mut ctx = CallbackContext::default();

        // First epoch - establishes baseline
        ctx.loss = 1.0;
        assert_eq!(es.on_epoch_end(&ctx), CallbackAction::Continue);

        // Improvement
        ctx.loss = 0.9;
        ctx.epoch = 1;
        assert_eq!(es.on_epoch_end(&ctx), CallbackAction::Continue);

        // No improvement (within delta)
        ctx.loss = 0.899;
        ctx.epoch = 2;
        assert_eq!(es.on_epoch_end(&ctx), CallbackAction::Continue);

        // Still no improvement
        ctx.loss = 0.899;
        ctx.epoch = 3;
        assert_eq!(es.on_epoch_end(&ctx), CallbackAction::Continue);

        // Still no improvement - should stop (patience=3)
        ctx.loss = 0.899;
        ctx.epoch = 4;
        assert_eq!(es.on_epoch_end(&ctx), CallbackAction::Stop);
    }

    /// An improving epoch clears the non-improving epoch counter.
    #[test]
    fn test_early_stopping_improvement_resets() {
        let mut es = EarlyStopping::new(2, 0.01);
        let mut ctx = CallbackContext::default();

        ctx.loss = 1.0;
        es.on_epoch_end(&ctx);

        ctx.loss = 1.0;
        ctx.epoch = 1;
        es.on_epoch_end(&ctx);

        // Improvement resets counter
        ctx.loss = 0.5;
        ctx.epoch = 2;
        assert_eq!(es.on_epoch_end(&ctx), CallbackAction::Continue);
        assert_eq!(es.epochs_without_improvement, 0);
    }

    #[test]
    fn test_early_stopping_with_restore_best() {
        let es = EarlyStopping::new(3, 0.001).with_restore_best();
        assert!(es.restore_best);
    }

    /// With validation monitoring on, the validation loss is what gets tracked.
    #[test]
    fn test_early_stopping_monitor_validation() {
        let mut es = EarlyStopping::new(3, 0.001).monitor_validation();
        assert!(es.monitor_val);

        let mut ctx = CallbackContext::default();
        ctx.loss = 1.0;
        ctx.val_loss = Some(0.5);
        es.on_epoch_end(&ctx);
        assert_eq!(es.best_loss, 0.5);
    }

    /// With validation monitoring on but no validation loss reported,
    /// the training loss is used instead.
    #[test]
    fn test_early_stopping_monitor_validation_fallback() {
        let mut es = EarlyStopping::new(3, 0.001).monitor_validation();

        let mut ctx = CallbackContext::default();
        ctx.loss = 0.75;
        ctx.val_loss = None;
        es.on_epoch_end(&ctx);
        assert_eq!(es.best_loss, 0.75);
    }

    #[test]
    fn test_early_stopping_reset() {
        let mut es = EarlyStopping::new(3, 0.001);
        let mut ctx = CallbackContext::default();
        ctx.loss = 0.5;
        es.on_epoch_end(&ctx);
        assert_eq!(es.best_loss, 0.5);

        es.reset();
        assert_eq!(es.best_loss, f32::INFINITY);
        assert_eq!(es.epochs_without_improvement, 0);
    }

    #[test]
    fn test_early_stopping_name() {
        let es = EarlyStopping::new(3, 0.001);
        assert_eq!(es.name(), "EarlyStopping");
    }

    #[test]
    fn test_early_stopping_clone() {
        let es = EarlyStopping::new(5, 0.01);
        let cloned = es.clone();
        assert_eq!(es.patience, cloned.patience);
    }
}
200
#[cfg(test)]
mod proptests {
    use super::*;
    use proptest::prelude::*;

    proptest! {
        /// Early stopping should always stop after patience epochs without improvement
        #[test]
        fn early_stopping_respects_patience(
            patience in 1usize..10,
            min_delta in 0.0001f32..0.1,
            initial_loss in 0.1f32..10.0,
        ) {
            let mut es = EarlyStopping::new(patience, min_delta);
            let mut ctx = CallbackContext::default();

            // First epoch establishes baseline
            ctx.loss = initial_loss;
            es.on_epoch_end(&ctx);

            // Run exactly `patience` further epochs without improvement:
            // every epoch before the last must continue, the last must stop.
            for epoch in 1..=patience {
                ctx.epoch = epoch;
                ctx.loss = initial_loss; // No improvement
                let action = es.on_epoch_end(&ctx);

                if epoch < patience {
                    prop_assert_eq!(action, CallbackAction::Continue);
                } else {
                    prop_assert_eq!(action, CallbackAction::Stop);
                }
            }
        }

        /// Early stopping counter should reset on improvement
        #[test]
        fn early_stopping_resets_on_improvement(
            patience in 2usize..10,
            min_delta in 0.001f32..0.1,
            initial_loss in 1.0f32..10.0,
            improvement in 0.2f32..0.5,
        ) {
            let mut es = EarlyStopping::new(patience, min_delta);
            let mut ctx = CallbackContext::default();

            // Establish baseline
            ctx.loss = initial_loss;
            es.on_epoch_end(&ctx);

            // One epoch without improvement (same loss as the baseline)
            ctx.epoch = 1;
            es.on_epoch_end(&ctx);
            prop_assert!(es.epochs_without_improvement >= 1);

            // Improvement resets counter; the generated improvement (>= 0.2)
            // always exceeds the generated min_delta (< 0.1)
            ctx.epoch = 2;
            ctx.loss = initial_loss - improvement;
            es.on_epoch_end(&ctx);
            prop_assert_eq!(es.epochs_without_improvement, 0);
        }
    }
}