Skip to main content

entrenar/train/callback/
checkpoint.rs

1//! Checkpoint callback for saving model state periodically
2
3use std::path::PathBuf;
4
5use super::traits::{CallbackAction, CallbackContext, TrainerCallback};
6
/// Training callback that writes checkpoint files while a model trains.
///
/// A checkpoint is written every `save_every` epochs (when configured) and/or whenever the
/// monitored loss drops below the best value seen so far (when `save_best` is enabled).
#[derive(Debug, Clone)]
pub struct CheckpointCallback {
    /// Directory that checkpoint files are written into.
    checkpoint_dir: PathBuf,
    /// Interval, in epochs, between periodic checkpoints (`None` = no periodic checkpoints).
    save_every: Option<usize>,
    /// Whether to write a checkpoint each time the loss improves on `best_loss`.
    save_best: bool,
    /// Lowest loss observed so far.
    best_loss: f32,
    /// Epoch of the most recent checkpoint save, or `None` before the first one.
    pub(crate) last_saved_epoch: Option<usize>,
}
23
24impl CheckpointCallback {
25    /// Create checkpoint callback saving to directory
26    pub fn new(checkpoint_dir: impl Into<PathBuf>) -> Self {
27        Self {
28            checkpoint_dir: checkpoint_dir.into(),
29            save_every: None,
30            save_best: true,
31            best_loss: f32::INFINITY,
32            last_saved_epoch: None,
33        }
34    }
35
36    /// Configure to save every N epochs
37    pub fn save_every(mut self, epochs: usize) -> Self {
38        self.save_every = Some(epochs);
39        self
40    }
41
42    /// Configure to save on best loss
43    pub fn save_best(mut self, save: bool) -> Self {
44        self.save_best = save;
45        self
46    }
47
48    /// Get checkpoint path for epoch
49    pub fn checkpoint_path(&self, epoch: usize) -> PathBuf {
50        self.checkpoint_dir.join(format!("checkpoint_epoch_{epoch}.json"))
51    }
52
53    /// Get best checkpoint path
54    pub fn best_checkpoint_path(&self) -> PathBuf {
55        self.checkpoint_dir.join("checkpoint_best.json")
56    }
57
58    /// Save checkpoint (placeholder - actual implementation needs model access)
59    fn save_checkpoint(&mut self, epoch: usize, is_best: bool) {
60        // Ensure directory exists
61        std::fs::create_dir_all(&self.checkpoint_dir).ok();
62
63        // Placeholder: In real implementation, would serialize model state
64        let path = if is_best { self.best_checkpoint_path() } else { self.checkpoint_path(epoch) };
65
66        // Write a marker file (real implementation would save model weights)
67        let timestamp = std::time::SystemTime::now()
68            .duration_since(std::time::UNIX_EPOCH)
69            .map(|d| d.as_secs())
70            .unwrap_or(0);
71        let info =
72            format!(r#"{{"epoch": {epoch}, "is_best": {is_best}, "timestamp": {timestamp}}}"#);
73        std::fs::write(&path, info).ok();
74
75        self.last_saved_epoch = Some(epoch);
76    }
77}
78
79impl TrainerCallback for CheckpointCallback {
80    fn on_epoch_end(&mut self, ctx: &CallbackContext) -> CallbackAction {
81        let mut should_save = false;
82        let mut is_best = false;
83
84        // Check if we should save periodically
85        if let Some(interval) = self.save_every {
86            if (ctx.epoch + 1).is_multiple_of(interval) {
87                should_save = true;
88            }
89        }
90
91        // Check if this is the best model
92        let loss = ctx.val_loss.unwrap_or(ctx.loss);
93        if self.save_best && loss < self.best_loss {
94            self.best_loss = loss;
95            should_save = true;
96            is_best = true;
97        }
98
99        if should_save {
100            self.save_checkpoint(ctx.epoch, is_best);
101        }
102
103        CallbackAction::Continue
104    }
105
106    fn on_train_end(&mut self, ctx: &CallbackContext) {
107        // Save final checkpoint
108        self.save_checkpoint(ctx.epoch, false);
109    }
110
111    fn name(&self) -> &'static str {
112        "CheckpointCallback"
113    }
114}
115
#[cfg(test)]
mod tests {
    use super::*;

    /// Per-epoch and best-checkpoint file names are derived from the configured directory.
    #[test]
    fn test_checkpoint_callback_paths() {
        let callback = CheckpointCallback::new("/tmp/checkpoints");
        let epoch_file = PathBuf::from("/tmp/checkpoints/checkpoint_epoch_5.json");
        let best_file = PathBuf::from("/tmp/checkpoints/checkpoint_best.json");

        assert_eq!(callback.checkpoint_path(5), epoch_file);
        assert_eq!(callback.best_checkpoint_path(), best_file);
    }

    /// With an interval of 2, the second epoch (index 1) triggers a save.
    #[test]
    fn test_checkpoint_callback_save_every() {
        let scratch = tempfile::tempdir().expect("temp file creation should succeed");
        let mut callback = CheckpointCallback::new(scratch.path()).save_every(2);

        let first_epoch = CallbackContext { loss: 1.0, ..Default::default() };
        callback.on_epoch_end(&first_epoch);

        let second_epoch = CallbackContext { epoch: 1, loss: 1.0, ..Default::default() };
        callback.on_epoch_end(&second_epoch);

        assert_eq!(callback.last_saved_epoch, Some(1));
    }

    /// With best-loss saving disabled and no interval, nothing is saved at epoch end.
    #[test]
    fn test_checkpoint_callback_save_best_disabled() {
        let scratch = tempfile::tempdir().expect("temp file creation should succeed");
        let mut callback = CheckpointCallback::new(scratch.path()).save_best(false);

        let context = CallbackContext { loss: 0.1, ..Default::default() };
        callback.on_epoch_end(&context);

        assert!(callback.last_saved_epoch.is_none());
    }

    /// The end of training always writes a checkpoint for the final epoch.
    #[test]
    fn test_checkpoint_callback_on_train_end() {
        let scratch = tempfile::tempdir().expect("temp file creation should succeed");
        let mut callback = CheckpointCallback::new(scratch.path());

        let final_context = CallbackContext { epoch: 5, ..Default::default() };
        callback.on_train_end(&final_context);

        assert_eq!(callback.last_saved_epoch, Some(5));
    }

    /// The callback reports its fixed name.
    #[test]
    fn test_checkpoint_callback_name() {
        let callback = CheckpointCallback::new("/tmp");
        assert_eq!(callback.name(), "CheckpointCallback");
    }

    /// Validation loss, when present, is the value tracked as the best loss.
    #[test]
    fn test_checkpoint_callback_val_loss_for_best() {
        let scratch = tempfile::tempdir().expect("temp file creation should succeed");
        let mut callback = CheckpointCallback::new(scratch.path());

        let context = CallbackContext { loss: 1.0, val_loss: Some(0.5), ..Default::default() };
        callback.on_epoch_end(&context);

        assert_eq!(callback.best_loss, 0.5);
    }

    /// Cloning keeps the configured checkpoint directory.
    #[test]
    fn test_checkpoint_callback_clone() {
        let original = CheckpointCallback::new("/tmp/test");
        let duplicate = original.clone();
        assert_eq!(original.checkpoint_dir, duplicate.checkpoint_dir);
    }
}
194
#[cfg(test)]
mod proptests {
    use super::*;
    use proptest::prelude::*;

    proptest! {
        /// Checkpoint paths are a pure function of the directory and the epoch number.
        #[test]
        fn checkpoint_paths_are_consistent(
            epoch in 0usize..1000,
        ) {
            let callback = CheckpointCallback::new("/tmp/test");

            // The per-epoch file name embeds the epoch number.
            let expected_epoch_path =
                PathBuf::from(format!("/tmp/test/checkpoint_epoch_{epoch}.json"));
            prop_assert_eq!(callback.checkpoint_path(epoch), expected_epoch_path);

            // The best-checkpoint file name does not depend on the epoch.
            let expected_best_path = PathBuf::from("/tmp/test/checkpoint_best.json");
            prop_assert_eq!(callback.best_checkpoint_path(), expected_best_path);
        }
    }
}