// entrenar/train/callback/early_stopping.rs
use super::traits::{CallbackAction, CallbackContext, TrainerCallback};

/// Early-stopping callback that ends training once the monitored loss stops
/// improving.
///
/// An epoch counts as an improvement only when its loss is strictly lower than
/// the current best loss minus `min_delta`. The callback tracks how many
/// consecutive epochs have passed without such an improvement and compares
/// that count against `patience`.
#[derive(Clone, Debug)]
pub struct EarlyStopping {
    /// Number of consecutive non-improving epochs tolerated.
    patience: usize,
    /// Margin by which a loss must beat `best_loss` to count as an improvement.
    min_delta: f32,
    /// Loss of the most recent improving epoch (`f32::INFINITY` before any).
    best_loss: f32,
    /// Consecutive epochs since the last improvement.
    pub(crate) epochs_without_improvement: usize,
    /// Set by [`with_restore_best`](Self::with_restore_best); this type never
    /// reads it. NOTE(review): presumably consumed by the trainer to restore
    /// the best weights on stop — confirm at the use site.
    pub(crate) restore_best: bool,
    /// When set, the validation loss (if available) is monitored instead of
    /// the training loss.
    monitor_val: bool,
}

impl EarlyStopping {
    /// Creates a tracker with the given `patience` and `min_delta`.
    ///
    /// The best loss starts at `f32::INFINITY`, so the first finite loss is
    /// always an improvement. `restore_best` and validation monitoring start
    /// disabled.
    ///
    /// `min_delta` is expected to be non-negative: with a negative value, a
    /// loss slightly *higher* than the best would still count as an
    /// improvement.
    #[must_use]
    pub fn new(patience: usize, min_delta: f32) -> Self {
        Self {
            patience,
            min_delta,
            best_loss: f32::INFINITY,
            epochs_without_improvement: 0,
            restore_best: false,
            monitor_val: false,
        }
    }

    /// Builder: sets the `restore_best` flag and returns the updated value.
    #[must_use]
    pub fn with_restore_best(mut self) -> Self {
        self.restore_best = true;
        self
    }

    /// Builder: monitor `val_loss` instead of the training loss.
    ///
    /// In `on_epoch_end`, epochs that carry no validation loss fall back to the
    /// training loss.
    #[must_use]
    pub fn monitor_validation(mut self) -> Self {
        self.monitor_val = true;
        self
    }

    /// Clears the tracked best loss and stale-epoch counter so the instance
    /// can be reused for a new run.
    ///
    /// Configuration (`patience`, `min_delta`, and the flags) is kept.
    pub fn reset(&mut self) {
        self.best_loss = f32::INFINITY;
        self.epochs_without_improvement = 0;
    }

    /// Records one epoch's `loss` and reports whether it improved.
    ///
    /// On improvement (`loss < best_loss - min_delta`), `best_loss` is updated
    /// and the stale counter is reset to 0. Otherwise the counter is
    /// incremented. A NaN loss never compares as lower, so it always counts as
    /// a non-improving epoch.
    fn check_improvement(&mut self, loss: f32) -> bool {
        if loss < self.best_loss - self.min_delta {
            self.best_loss = loss;
            self.epochs_without_improvement = 0;
            true
        } else {
            self.epochs_without_improvement += 1;
            false
        }
    }
}

81impl TrainerCallback for EarlyStopping {
82 fn on_epoch_end(&mut self, ctx: &CallbackContext) -> CallbackAction {
83 let loss = if self.monitor_val { ctx.val_loss.unwrap_or(ctx.loss) } else { ctx.loss };
85 self.check_improvement(loss);
86
87 if self.epochs_without_improvement >= self.patience {
88 eprintln!(
89 "Early stopping: no improvement for {} epochs (best loss: {:.4})",
90 self.patience, self.best_loss
91 );
92 CallbackAction::Stop
93 } else {
94 CallbackAction::Continue
95 }
96 }
97
98 fn name(&self) -> &'static str {
99 "EarlyStopping"
100 }
101}
102
#[cfg(test)]
mod tests {
    use super::*;

    // Drives `es` through one epoch per entry of `losses`, using the entry's
    // position as the epoch index, and collects the action returned each time.
    fn run_epochs(es: &mut EarlyStopping, losses: &[f32]) -> Vec<CallbackAction> {
        let mut ctx = CallbackContext::default();
        let mut actions = Vec::with_capacity(losses.len());
        for (epoch, &loss) in losses.iter().enumerate() {
            ctx.epoch = epoch;
            ctx.loss = loss;
            actions.push(es.on_epoch_end(&ctx));
        }
        actions
    }

    #[test]
    fn test_early_stopping_patience() {
        let mut es = EarlyStopping::new(3, 0.001);
        let actions = run_epochs(&mut es, &[1.0, 0.9, 0.899, 0.899, 0.899]);
        assert_eq!(
            actions,
            vec![
                CallbackAction::Continue,
                CallbackAction::Continue,
                CallbackAction::Continue,
                CallbackAction::Continue,
                CallbackAction::Stop,
            ]
        );
    }

    #[test]
    fn test_early_stopping_improvement_resets() {
        let mut es = EarlyStopping::new(2, 0.01);
        let actions = run_epochs(&mut es, &[1.0, 1.0, 0.5]);
        assert_eq!(actions.last(), Some(&CallbackAction::Continue));
        assert_eq!(es.epochs_without_improvement, 0);
    }

    #[test]
    fn test_early_stopping_with_restore_best() {
        assert!(EarlyStopping::new(3, 0.001).with_restore_best().restore_best);
    }

    #[test]
    fn test_early_stopping_monitor_validation() {
        let mut es = EarlyStopping::new(3, 0.001).monitor_validation();
        assert!(es.monitor_val);

        let mut ctx = CallbackContext::default();
        ctx.val_loss = Some(0.5);
        ctx.loss = 1.0;
        es.on_epoch_end(&ctx);
        // The validation loss, not the training loss, becomes the best loss.
        assert_eq!(es.best_loss, 0.5);
    }

    #[test]
    fn test_early_stopping_reset() {
        let mut es = EarlyStopping::new(3, 0.001);
        run_epochs(&mut es, &[0.5]);
        assert_eq!(es.best_loss, 0.5);

        es.reset();
        assert_eq!(es.best_loss, f32::INFINITY);
        assert_eq!(es.epochs_without_improvement, 0);
    }

    #[test]
    fn test_early_stopping_name() {
        assert_eq!(EarlyStopping::new(3, 0.001).name(), "EarlyStopping");
    }

    #[test]
    fn test_early_stopping_clone() {
        let original = EarlyStopping::new(5, 0.01);
        let duplicate = original.clone();
        assert_eq!(original.patience, duplicate.patience);
    }
}

#[cfg(test)]
mod proptests {
    use super::*;
    use proptest::prelude::*;

    // Runs a single epoch with the given index and training loss on a fresh
    // context and returns the callback's decision.
    fn step(es: &mut EarlyStopping, epoch: usize, loss: f32) -> CallbackAction {
        let mut ctx = CallbackContext::default();
        ctx.epoch = epoch;
        ctx.loss = loss;
        es.on_epoch_end(&ctx)
    }

    proptest! {
        // A loss that never improves must yield `Continue` for the first
        // `patience - 1` stale epochs and `Stop` on the `patience`-th one.
        #[test]
        fn early_stopping_respects_patience(
            patience in 1usize..10,
            min_delta in 0.0001f32..0.1,
            initial_loss in 0.1f32..10.0,
        ) {
            let mut es = EarlyStopping::new(patience, min_delta);

            // Epoch 0 establishes the baseline best loss.
            step(&mut es, 0, initial_loss);

            // Every later epoch repeats the baseline, so none of them improve.
            for epoch in 1..=patience {
                let action = step(&mut es, epoch, initial_loss);
                let expected = if epoch == patience {
                    CallbackAction::Stop
                } else {
                    CallbackAction::Continue
                };
                prop_assert_eq!(action, expected);
            }
        }

        // An improvement larger than `min_delta` must clear the stale-epoch
        // counter.
        #[test]
        fn early_stopping_resets_on_improvement(
            patience in 2usize..10,
            min_delta in 0.001f32..0.1,
            initial_loss in 1.0f32..10.0,
            improvement in 0.2f32..0.5,
        ) {
            let mut es = EarlyStopping::new(patience, min_delta);

            step(&mut es, 0, initial_loss);

            // Repeating the same loss is one stale epoch.
            step(&mut es, 1, initial_loss);
            prop_assert!(es.epochs_without_improvement >= 1);

            // `improvement` always exceeds `min_delta`, so this resets the counter.
            step(&mut es, 2, initial_loss - improvement);
            prop_assert_eq!(es.epochs_without_improvement, 0);
        }
    }
}