entrenar/train/callback/
checkpoint.rs1use std::path::PathBuf;
4
5use super::traits::{CallbackAction, CallbackContext, TrainerCallback};
6
/// Trainer callback that writes checkpoint metadata files while training runs.
///
/// Configure it with [`CheckpointCallback::new`] followed by the builder-style
/// `save_every` / `save_best` methods.
#[derive(Debug, Clone)]
pub struct CheckpointCallback {
    /// Directory that receives every checkpoint file.
    checkpoint_dir: PathBuf,
    /// Periodic save interval in epochs; `None` turns periodic saving off.
    save_every: Option<usize>,
    /// Whether a new lowest loss triggers a "best" checkpoint.
    save_best: bool,
    /// Lowest loss seen so far while `save_best` is enabled.
    best_loss: f32,
    /// Epoch handed to the most recent save, if any save has happened yet.
    pub(crate) last_saved_epoch: Option<usize>,
}
23
24impl CheckpointCallback {
25 pub fn new(checkpoint_dir: impl Into<PathBuf>) -> Self {
27 Self {
28 checkpoint_dir: checkpoint_dir.into(),
29 save_every: None,
30 save_best: true,
31 best_loss: f32::INFINITY,
32 last_saved_epoch: None,
33 }
34 }
35
36 pub fn save_every(mut self, epochs: usize) -> Self {
38 self.save_every = Some(epochs);
39 self
40 }
41
42 pub fn save_best(mut self, save: bool) -> Self {
44 self.save_best = save;
45 self
46 }
47
48 pub fn checkpoint_path(&self, epoch: usize) -> PathBuf {
50 self.checkpoint_dir.join(format!("checkpoint_epoch_{epoch}.json"))
51 }
52
53 pub fn best_checkpoint_path(&self) -> PathBuf {
55 self.checkpoint_dir.join("checkpoint_best.json")
56 }
57
58 fn save_checkpoint(&mut self, epoch: usize, is_best: bool) {
60 std::fs::create_dir_all(&self.checkpoint_dir).ok();
62
63 let path = if is_best { self.best_checkpoint_path() } else { self.checkpoint_path(epoch) };
65
66 let timestamp = std::time::SystemTime::now()
68 .duration_since(std::time::UNIX_EPOCH)
69 .map(|d| d.as_secs())
70 .unwrap_or(0);
71 let info =
72 format!(r#"{{"epoch": {epoch}, "is_best": {is_best}, "timestamp": {timestamp}}}"#);
73 std::fs::write(&path, info).ok();
74
75 self.last_saved_epoch = Some(epoch);
76 }
77}
78
79impl TrainerCallback for CheckpointCallback {
80 fn on_epoch_end(&mut self, ctx: &CallbackContext) -> CallbackAction {
81 let mut should_save = false;
82 let mut is_best = false;
83
84 if let Some(interval) = self.save_every {
86 if (ctx.epoch + 1).is_multiple_of(interval) {
87 should_save = true;
88 }
89 }
90
91 let loss = ctx.val_loss.unwrap_or(ctx.loss);
93 if self.save_best && loss < self.best_loss {
94 self.best_loss = loss;
95 should_save = true;
96 is_best = true;
97 }
98
99 if should_save {
100 self.save_checkpoint(ctx.epoch, is_best);
101 }
102
103 CallbackAction::Continue
104 }
105
106 fn on_train_end(&mut self, ctx: &CallbackContext) {
107 self.save_checkpoint(ctx.epoch, false);
109 }
110
111 fn name(&self) -> &'static str {
112 "CheckpointCallback"
113 }
114}
115
#[cfg(test)]
mod tests {
    use super::*;

    /// Both path helpers join their fixed file names onto the configured directory.
    #[test]
    fn test_checkpoint_callback_paths() {
        let callback = CheckpointCallback::new("/tmp/checkpoints");

        let epoch_path = callback.checkpoint_path(5);
        let best_path = callback.best_checkpoint_path();

        assert_eq!(epoch_path, PathBuf::from("/tmp/checkpoints/checkpoint_epoch_5.json"));
        assert_eq!(best_path, PathBuf::from("/tmp/checkpoints/checkpoint_best.json"));
    }

    /// With an interval of 2, the second epoch (index 1) is recorded as saved.
    #[test]
    fn test_checkpoint_callback_save_every() {
        let scratch = tempfile::tempdir().expect("temp file creation should succeed");
        let mut callback = CheckpointCallback::new(scratch.path()).save_every(2);

        let mut context = CallbackContext { loss: 1.0, ..Default::default() };
        callback.on_epoch_end(&context);

        context.epoch = 1;
        callback.on_epoch_end(&context);

        assert_eq!(callback.last_saved_epoch, Some(1));
    }

    /// With best-saving off and no interval set, an epoch end saves nothing.
    #[test]
    fn test_checkpoint_callback_save_best_disabled() {
        let scratch = tempfile::tempdir().expect("temp file creation should succeed");
        let mut callback = CheckpointCallback::new(scratch.path()).save_best(false);

        let context = CallbackContext { loss: 0.1, ..Default::default() };
        callback.on_epoch_end(&context);

        assert!(callback.last_saved_epoch.is_none());
    }

    /// The end of training records the final epoch as saved.
    #[test]
    fn test_checkpoint_callback_on_train_end() {
        let scratch = tempfile::tempdir().expect("temp file creation should succeed");
        let mut callback = CheckpointCallback::new(scratch.path());

        let final_context = CallbackContext { epoch: 5, ..Default::default() };
        callback.on_train_end(&final_context);

        assert_eq!(callback.last_saved_epoch, Some(5));
    }

    /// The callback reports its fixed name.
    #[test]
    fn test_checkpoint_callback_name() {
        let callback = CheckpointCallback::new("/tmp");
        assert_eq!(callback.name(), "CheckpointCallback");
    }

    /// When a validation loss is present it, not the training loss, becomes the best loss.
    #[test]
    fn test_checkpoint_callback_val_loss_for_best() {
        let scratch = tempfile::tempdir().expect("temp file creation should succeed");
        let mut callback = CheckpointCallback::new(scratch.path());

        let context = CallbackContext { loss: 1.0, val_loss: Some(0.5), ..Default::default() };
        callback.on_epoch_end(&context);

        assert_eq!(callback.best_loss, 0.5);
    }

    /// A clone keeps the same checkpoint directory as its source.
    #[test]
    fn test_checkpoint_callback_clone() {
        let original = CheckpointCallback::new("/tmp/test");
        let duplicate = original.clone();
        assert_eq!(original.checkpoint_dir, duplicate.checkpoint_dir);
    }
}
194
#[cfg(test)]
mod proptests {
    use super::*;
    use proptest::prelude::*;

    proptest! {
        /// For any epoch in `0..1000`, both path helpers produce the expected
        /// file names under the configured directory.
        #[test]
        fn checkpoint_paths_are_consistent(epoch in 0usize..1000) {
            let callback = CheckpointCallback::new("/tmp/test");

            let expected_epoch_path =
                PathBuf::from(&format!("/tmp/test/checkpoint_epoch_{epoch}.json"));
            prop_assert_eq!(callback.checkpoint_path(epoch), expected_epoch_path);

            let expected_best_path = PathBuf::from("/tmp/test/checkpoint_best.json");
            prop_assert_eq!(callback.best_checkpoint_path(), expected_best_path);
        }
    }
}