1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
//! Model checkpoint callback implementation
use crate::callbacks::{Callback, CallbackContext, CallbackTiming};
use crate::error::Result;
use scirs2_core::ndarray::ScalarOperand;
use scirs2_core::numeric::{Float, NumAssign};
use std::fmt::Debug;
use std::path::{Path, PathBuf};
/// Model checkpoint callback that saves the model after every epoch
/// and optionally only saves the best models based on a monitored metric.
///
/// Note: trait bounds are intentionally omitted from the struct definition
/// (per Rust API guidelines); they live on the impl blocks that need them.
pub struct ModelCheckpoint<F> {
    /// Directory (or single file path) to save the model to
    filepath: PathBuf,
    /// Whether to save only the best model based on the monitored metric
    save_best_only: bool,
    /// Whether to monitor validation loss (true) or training loss (false)
    monitor_val_loss: bool,
    /// Whether improvement means the metric decreased (lower is better, the
    /// default) or increased (higher is better)
    monitor_decrease: bool,
    /// Best value of the monitored metric seen so far; `None` until the
    /// first monitored epoch completes
    best_value: Option<F>,
}
impl<F: Float + Debug + ScalarOperand + NumAssign> ModelCheckpoint<F> {
    /// Build a checkpoint callback with the default monitoring policy:
    /// validation loss is watched, and lower values count as improvements.
    ///
    /// # Arguments
    /// * `filepath` - Directory or file path where models should be written
    /// * `save_best_only` - When true, save only when the monitored metric improves
    #[allow(dead_code)]
    pub fn new<P: AsRef<Path>>(filepath: P, save_best_only: bool) -> Self {
        let filepath = filepath.as_ref().to_path_buf();
        Self {
            filepath,
            save_best_only,
            best_value: None,
            monitor_decrease: true,
            monitor_val_loss: true,
        }
    }

    /// Switch monitoring from validation loss to training loss.
    pub fn monitor_train_loss(mut self) -> Self {
        self.monitor_val_loss = false;
        self
    }

    /// Treat increasing metric values as improvements (higher is better).
    /// The default treats decreasing values as improvements.
    pub fn monitor_increase(mut self) -> Self {
        self.monitor_decrease = false;
        self
    }
}
impl<F: Float + Debug + ScalarOperand + NumAssign> Callback<F> for ModelCheckpoint<F> {
    /// React to training events; only `AfterEpoch` triggers any work.
    ///
    /// When `save_best_only` is set, the monitored loss (validation or
    /// training, per configuration) is compared against the best value seen
    /// so far and a save only happens on improvement; otherwise a save
    /// happens every epoch.
    fn on_event(&mut self, timing: CallbackTiming, context: &mut CallbackContext<F>) -> Result<()> {
        // Nothing to do outside the end-of-epoch event.
        if timing != CallbackTiming::AfterEpoch {
            return Ok(());
        }
        let should_save = if self.save_best_only {
            // Pick the configured metric; it may be absent (e.g. no
            // validation run this epoch), in which case we skip saving.
            let monitored = if self.monitor_val_loss {
                context.val_loss
            } else {
                context.epoch_loss
            };
            match monitored {
                Some(current) => {
                    let improved = match self.best_value {
                        // First observed value always counts as the best.
                        None => true,
                        // Lower is better.
                        Some(best) if self.monitor_decrease => current < best,
                        // Higher is better.
                        Some(best) => current > best,
                    };
                    if improved {
                        self.best_value = Some(current);
                    }
                    improved
                }
                // No value to monitor, don't save.
                None => false,
            }
        } else {
            // Not tracking a best value: save unconditionally.
            true
        };
        if should_save {
            // Epochs are 0-based internally; display them 1-based.
            let display_epoch = context.epoch + 1;
            let target = if self.filepath.is_dir() {
                self.filepath
                    .join(format!("model_epoch_{}.pth", display_epoch))
            } else {
                self.filepath.clone()
            };
            println!("Saving model to: {}", target.display());
            // In a real implementation, we'd save the model here
            // model.save(target);
        }
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A fresh checkpoint stores the given path/flag and uses the default
    /// monitoring policy (validation loss, lower-is-better, no best yet).
    #[test]
    fn test_model_checkpoint_creation() {
        let checkpoint = ModelCheckpoint::<f32>::new("test_path", true);
        assert_eq!(
            checkpoint
                .filepath
                .to_str()
                .expect("filepath should be valid UTF-8"),
            "test_path"
        );
        assert!(checkpoint.save_best_only);
        assert!(checkpoint.monitor_val_loss);
        assert!(checkpoint.monitor_decrease);
        assert!(checkpoint.best_value.is_none());
    }

    /// `monitor_train_loss` switches monitoring off validation loss.
    #[test]
    fn test_monitor_train_loss() {
        let checkpoint = ModelCheckpoint::<f32>::new("test_path", true).monitor_train_loss();
        assert!(!checkpoint.monitor_val_loss);
    }

    /// `monitor_increase` flips the improvement direction to higher-is-better.
    #[test]
    fn test_monitor_increase() {
        let checkpoint = ModelCheckpoint::<f32>::new("test_path", true).monitor_increase();
        assert!(!checkpoint.monitor_decrease);
    }
}