use std::path::PathBuf;
use super::traits::{CallbackAction, CallbackContext, TrainerCallback};
/// Trainer callback state for writing checkpoint files into a directory.
///
/// Construct it with [`CheckpointCallback::new`] and adjust it with the
/// `save_every` / `save_best` builder methods.
#[derive(Clone, Debug)]
pub struct CheckpointCallback {
    /// Directory that every checkpoint file is joined onto.
    checkpoint_dir: PathBuf,
    /// Periodic save interval in epochs; `None` means no periodic saves.
    save_every: Option<usize>,
    /// Whether an improvement of the monitored loss triggers a save.
    save_best: bool,
    /// Lowest monitored loss seen so far (starts at `f32::INFINITY`).
    best_loss: f32,
    /// Epoch of the most recent checkpoint save, or `None` before any save.
    pub(crate) last_saved_epoch: Option<usize>,
}
impl CheckpointCallback {
    /// Creates a callback that writes its checkpoint files under `checkpoint_dir`.
    ///
    /// Periodic saving starts disabled and best-loss saving starts enabled.
    /// The directory is not touched until the first checkpoint is written.
    pub fn new(checkpoint_dir: impl Into<PathBuf>) -> Self {
        Self {
            checkpoint_dir: checkpoint_dir.into(),
            save_every: None,
            save_best: true,
            best_loss: f32::INFINITY,
            last_saved_epoch: None,
        }
    }

    /// Enables a periodic checkpoint every `epochs` epochs.
    ///
    /// Consumes and returns the callback, so the result must be kept.
    #[must_use]
    pub fn save_every(mut self, epochs: usize) -> Self {
        self.save_every = Some(epochs);
        self
    }

    /// Enables or disables saving whenever the monitored loss improves.
    ///
    /// Consumes and returns the callback, so the result must be kept.
    #[must_use]
    pub fn save_best(mut self, save: bool) -> Self {
        self.save_best = save;
        self
    }

    /// Returns the path of the per-epoch checkpoint file for `epoch`.
    pub fn checkpoint_path(&self, epoch: usize) -> PathBuf {
        self.checkpoint_dir.join(format!("checkpoint_epoch_{epoch}.json"))
    }

    /// Returns the path of the single "best" checkpoint file.
    pub fn best_checkpoint_path(&self) -> PathBuf {
        self.checkpoint_dir.join("checkpoint_best.json")
    }

    /// Writes the checkpoint record for `epoch` and remembers it as the last save.
    ///
    /// With `is_best` set the record goes to [`Self::best_checkpoint_path`],
    /// otherwise to [`Self::checkpoint_path`] for `epoch`.
    ///
    /// This method returns nothing, so a failed write cannot be propagated to
    /// the caller. Instead the failure is reported on stderr and
    /// `last_saved_epoch` is left unchanged, so it never names an epoch whose
    /// file was not actually written.
    fn save_checkpoint(&mut self, epoch: usize, is_best: bool) {
        let path = if is_best {
            self.best_checkpoint_path()
        } else {
            self.checkpoint_path(epoch)
        };

        match self.write_checkpoint_file(&path, epoch, is_best) {
            Ok(()) => self.last_saved_epoch = Some(epoch),
            Err(err) => eprintln!(
                "CheckpointCallback: failed to write checkpoint {}: {err}",
                path.display()
            ),
        }
    }

    /// Creates the checkpoint directory if needed and writes the JSON record to `path`.
    fn write_checkpoint_file(
        &self,
        path: &std::path::Path,
        epoch: usize,
        is_best: bool,
    ) -> std::io::Result<()> {
        std::fs::create_dir_all(&self.checkpoint_dir)?;

        // A system clock set before the Unix epoch yields a timestamp of 0
        // rather than an error.
        let timestamp = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .map(|d| d.as_secs())
            .unwrap_or(0);
        let info =
            format!(r#"{{"epoch": {epoch}, "is_best": {is_best}, "timestamp": {timestamp}}}"#);

        std::fs::write(path, info)
    }
}
impl TrainerCallback for CheckpointCallback {
    /// Decides at the end of an epoch which checkpoints to write.
    ///
    /// There are two independent triggers:
    /// * periodic — `save_every` is set and `ctx.epoch + 1` is a multiple of it;
    /// * best — `save_best` is on and the monitored loss (`ctx.val_loss` when
    ///   present, otherwise `ctx.loss`) is strictly below the best seen so far.
    ///
    /// When both fire on the same epoch, both the per-epoch file and the best
    /// file are written, so a new best never suppresses a periodic checkpoint.
    /// Always returns [`CallbackAction::Continue`].
    fn on_epoch_end(&mut self, ctx: &CallbackContext) -> CallbackAction {
        let periodic_due = self
            .save_every
            .is_some_and(|interval| (ctx.epoch + 1).is_multiple_of(interval));

        let loss = ctx.val_loss.unwrap_or(ctx.loss);
        // A NaN loss compares false here and is therefore never recorded as best.
        let is_best = self.save_best && loss < self.best_loss;
        if is_best {
            self.best_loss = loss;
        }

        if periodic_due {
            self.save_checkpoint(ctx.epoch, false);
        }
        if is_best {
            self.save_checkpoint(ctx.epoch, true);
        }

        CallbackAction::Continue
    }

    /// Writes a final per-epoch checkpoint for `ctx.epoch` when training ends.
    fn on_train_end(&mut self, ctx: &CallbackContext) {
        self.save_checkpoint(ctx.epoch, false);
    }

    /// Returns the fixed name `"CheckpointCallback"`.
    fn name(&self) -> &'static str {
        "CheckpointCallback"
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Both path helpers join a fixed file name onto the configured directory.
    #[test]
    fn test_checkpoint_callback_paths() {
        let callback = CheckpointCallback::new("/tmp/checkpoints");

        let epoch_path = callback.checkpoint_path(5);
        assert_eq!(epoch_path, PathBuf::from("/tmp/checkpoints/checkpoint_epoch_5.json"));

        let best_path = callback.best_checkpoint_path();
        assert_eq!(best_path, PathBuf::from("/tmp/checkpoints/checkpoint_best.json"));
    }

    /// With an interval of 2, the second epoch (index 1) is the last one saved.
    #[test]
    fn test_checkpoint_callback_save_every() {
        let scratch = tempfile::tempdir().expect("temp file creation should succeed");
        let mut callback = CheckpointCallback::new(scratch.path()).save_every(2);

        let mut context = CallbackContext { loss: 1.0, ..Default::default() };
        callback.on_epoch_end(&context);

        context.epoch = 1;
        callback.on_epoch_end(&context);

        assert_eq!(callback.last_saved_epoch, Some(1));
    }

    /// With best-saving off and no interval, ending an epoch saves nothing.
    #[test]
    fn test_checkpoint_callback_save_best_disabled() {
        let scratch = tempfile::tempdir().expect("temp file creation should succeed");
        let mut callback = CheckpointCallback::new(scratch.path()).save_best(false);

        let context = CallbackContext { loss: 0.1, ..Default::default() };
        callback.on_epoch_end(&context);

        assert_eq!(callback.last_saved_epoch, None);
    }

    /// The end of training saves a checkpoint for the context's epoch.
    #[test]
    fn test_checkpoint_callback_on_train_end() {
        let scratch = tempfile::tempdir().expect("temp file creation should succeed");
        let mut callback = CheckpointCallback::new(scratch.path());

        let context = CallbackContext { epoch: 5, ..Default::default() };
        callback.on_train_end(&context);

        assert_eq!(callback.last_saved_epoch, Some(5));
    }

    /// The callback reports its fixed name.
    #[test]
    fn test_checkpoint_callback_name() {
        let callback = CheckpointCallback::new("/tmp");
        assert_eq!(callback.name(), "CheckpointCallback");
    }

    /// The validation loss, when present, is the value tracked as best.
    #[test]
    fn test_checkpoint_callback_val_loss_for_best() {
        let scratch = tempfile::tempdir().expect("temp file creation should succeed");
        let mut callback = CheckpointCallback::new(scratch.path());

        let context = CallbackContext { loss: 1.0, val_loss: Some(0.5), ..Default::default() };
        callback.on_epoch_end(&context);

        assert_eq!(callback.best_loss, 0.5);
    }

    /// A clone keeps the same checkpoint directory as its source.
    #[test]
    fn test_checkpoint_callback_clone() {
        let original = CheckpointCallback::new("/tmp/test");
        let duplicate = original.clone();
        assert_eq!(original.checkpoint_dir, duplicate.checkpoint_dir);
    }
}
#[cfg(test)]
mod proptests {
    use super::*;
    use proptest::prelude::*;

    proptest! {
        // For any epoch in range, the per-epoch path embeds the epoch number
        // and the best path stays fixed.
        #[test]
        fn checkpoint_paths_are_consistent(epoch in 0usize..1000) {
            let callback = CheckpointCallback::new("/tmp/test");

            let expected_epoch_path =
                PathBuf::from(format!("/tmp/test/checkpoint_epoch_{epoch}.json"));
            prop_assert_eq!(callback.checkpoint_path(epoch), expected_epoch_path);

            let expected_best_path = PathBuf::from("/tmp/test/checkpoint_best.json");
            prop_assert_eq!(callback.best_checkpoint_path(), expected_best_path);
        }
    }
}