use crate::error::{ModelError, ModelResult};
use scirs2_core::ndarray::Array1;
use serde::{Deserialize, Serialize};
use std::collections::VecDeque;
use std::path::{Path, PathBuf};
/// Serializable record of a single training checkpoint: epoch/step,
/// losses, creation timestamp, and arbitrary named scalar metrics.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CheckpointMetadata {
    /// Epoch number the checkpoint was taken at.
    pub epoch: usize,
    /// Global training step the checkpoint was taken at.
    pub step: usize,
    /// Validation loss at this checkpoint, if evaluated.
    pub val_loss: Option<f32>,
    /// Training loss at this checkpoint, if recorded.
    pub train_loss: Option<f32>,
    /// Best validation loss seen so far, if tracked.
    // NOTE(review): never written by code in this file — presumably filled
    // in by a caller; confirm.
    pub best_val_loss: Option<f32>,
    /// RFC 3339 UTC timestamp of when this metadata was created.
    pub timestamp: String,
    /// Model identifier; defaults to "unknown".
    pub model_type: String,
    /// Arbitrary named scalar metrics (e.g. accuracy).
    pub metrics: std::collections::HashMap<String, f32>,
}
impl CheckpointMetadata {
    /// Creates metadata for a checkpoint taken at `epoch`/`step`.
    ///
    /// Losses start unset, `model_type` defaults to `"unknown"`, and the
    /// timestamp is captured as the current UTC time in RFC 3339 format.
    pub fn new(epoch: usize, step: usize) -> Self {
        Self {
            epoch,
            step,
            val_loss: None,
            train_loss: None,
            best_val_loss: None,
            timestamp: chrono::Utc::now().to_rfc3339(),
            model_type: "unknown".to_string(),
            metrics: std::collections::HashMap::new(),
        }
    }

    /// Builder-style setter for the validation loss.
    pub fn with_val_loss(mut self, loss: f32) -> Self {
        self.val_loss = Some(loss);
        self
    }

    /// Builder-style setter for the training loss.
    pub fn with_train_loss(mut self, loss: f32) -> Self {
        self.train_loss = Some(loss);
        self
    }

    /// Records an arbitrary named metric (e.g. accuracy).
    ///
    /// Generalized to accept anything convertible into a `String`
    /// (`&str`, `String`, ...); existing callers passing `String`
    /// compile unchanged.
    pub fn with_metric(mut self, name: impl Into<String>, value: f32) -> Self {
        self.metrics.insert(name.into(), value);
        self
    }
}
/// Tracks checkpoint files on disk plus the best-so-far metadata.
#[derive(Debug, Clone)]
pub struct CheckpointManager {
    /// Directory all checkpoint/metadata files are written under.
    pub checkpoint_dir: PathBuf,
    /// Maximum number of recent checkpoints retained by cleanup.
    pub max_checkpoints: usize,
    /// Number of "best" checkpoints to keep.
    // NOTE(review): not consulted by any logic in this file — confirm intent.
    pub keep_best: usize,
    /// Metadata of the checkpoint with the lowest validation loss seen.
    pub best_checkpoint: Option<CheckpointMetadata>,
    // FIFO of recently written checkpoint paths, oldest first.
    // NOTE(review): nothing in this file pushes into this queue — presumably
    // callers were expected to; verify against trainer code.
    recent_checkpoints: VecDeque<PathBuf>,
}
impl CheckpointManager {
    /// Creates a manager rooted at `checkpoint_dir` with default retention
    /// limits (at most 5 recent checkpoints, 3 best).
    pub fn new<P: AsRef<Path>>(checkpoint_dir: P) -> Self {
        Self {
            checkpoint_dir: checkpoint_dir.as_ref().to_path_buf(),
            max_checkpoints: 5,
            keep_best: 3,
            best_checkpoint: None,
            recent_checkpoints: VecDeque::new(),
        }
    }

    /// Builder-style setter for how many recent checkpoints
    /// `cleanup_old_checkpoints` retains.
    pub fn max_checkpoints(mut self, max: usize) -> Self {
        self.max_checkpoints = max;
        self
    }

    /// Builder-style setter for the number of best checkpoints to keep.
    ///
    /// NOTE(review): the `keep_best` field is not consulted by any retention
    /// logic in this file — confirm its intended semantics with the caller.
    pub fn keep_best(mut self, n: usize) -> Self {
        self.keep_best = n;
        self
    }

    /// Path of the binary checkpoint file for `epoch`.
    pub fn checkpoint_path(&self, epoch: usize) -> PathBuf {
        self.checkpoint_dir
            .join(format!("checkpoint_epoch_{}.bin", epoch))
    }

    /// Path of the single best-checkpoint file.
    pub fn best_checkpoint_path(&self) -> PathBuf {
        self.checkpoint_dir.join("best_checkpoint.bin")
    }

    /// Writes `metadata` as pretty-printed JSON to
    /// `metadata_epoch_{epoch}.json`, creating the checkpoint directory
    /// first if necessary.
    ///
    /// # Errors
    /// Fails if the directory cannot be created, serialization fails, or
    /// the file cannot be written.
    // NOTE(review): save paths reuse `ModelError::load_error` — presumably
    // the only file-op constructor available on ModelError; confirm.
    pub fn save_metadata(&self, metadata: &CheckpointMetadata) -> ModelResult<()> {
        let path = self
            .checkpoint_dir
            .join(format!("metadata_epoch_{}.json", metadata.epoch));
        std::fs::create_dir_all(&self.checkpoint_dir).map_err(|e| {
            ModelError::load_error(
                "checkpoint save",
                format!("Failed to create checkpoint directory: {}", e),
            )
        })?;
        let json = serde_json::to_string_pretty(metadata).map_err(|e| {
            ModelError::load_error(
                "checkpoint save",
                format!("Failed to serialize metadata: {}", e),
            )
        })?;
        std::fs::write(&path, json).map_err(|e| {
            ModelError::load_error(
                "checkpoint save",
                format!("Failed to write metadata: {}", e),
            )
        })?;
        tracing::info!("Saved checkpoint metadata to {:?}", path);
        Ok(())
    }

    /// Reads back the metadata JSON previously written for `epoch`.
    ///
    /// # Errors
    /// Fails if the file cannot be read or does not parse as
    /// [`CheckpointMetadata`].
    pub fn load_metadata(&self, epoch: usize) -> ModelResult<CheckpointMetadata> {
        let path = self
            .checkpoint_dir
            .join(format!("metadata_epoch_{}.json", epoch));
        let json = std::fs::read_to_string(&path).map_err(|e| {
            ModelError::load_error("checkpoint load", format!("Failed to read metadata: {}", e))
        })?;
        let metadata: CheckpointMetadata = serde_json::from_str(&json).map_err(|e| {
            ModelError::load_error(
                "checkpoint load",
                format!("Failed to deserialize metadata: {}", e),
            )
        })?;
        Ok(metadata)
    }

    /// Whether `val_loss` strictly beats the recorded best validation loss.
    /// Returns `true` when no comparable best exists yet (no best
    /// checkpoint, or a best whose `val_loss` is unset).
    pub fn is_best(&self, val_loss: f32) -> bool {
        if let Some(ref best) = self.best_checkpoint {
            if let Some(best_loss) = best.val_loss {
                return val_loss < best_loss;
            }
        }
        true
    }

    /// Promotes `metadata` to best if its `val_loss` is an improvement.
    /// Metadata without a `val_loss` is ignored.
    pub fn update_best(&mut self, metadata: CheckpointMetadata) {
        if let Some(val_loss) = metadata.val_loss {
            if self.is_best(val_loss) {
                self.best_checkpoint = Some(metadata);
                tracing::info!("New best checkpoint with val_loss: {}", val_loss);
            }
        }
    }

    /// Registers a freshly written checkpoint file for retention tracking.
    ///
    /// Fix: `recent_checkpoints` was never populated anywhere in this file,
    /// so `cleanup_old_checkpoints` could never delete anything. Call this
    /// after each successful save (e.g. with the path returned by
    /// `save_weights`), then run the cleanup.
    pub fn record_checkpoint(&mut self, path: PathBuf) {
        self.recent_checkpoints.push_back(path);
    }

    /// Saves linear-model parameters (`weights` + `bias`) as JSON to
    /// `weights_step_{step}.json` and returns the path written.
    ///
    /// # Errors
    /// Fails if the directory cannot be created, serialization fails, or
    /// the file cannot be written.
    pub fn save_weights(
        &self,
        weights: &Array1<f32>,
        bias: f32,
        step: usize,
    ) -> ModelResult<PathBuf> {
        std::fs::create_dir_all(&self.checkpoint_dir).map_err(|e| {
            ModelError::load_error(
                "weight save",
                format!("failed to create checkpoint directory: {e}"),
            )
        })?;
        let path = self
            .checkpoint_dir
            .join(format!("weights_step_{step}.json"));
        let weights_vec: Vec<f32> = weights.iter().copied().collect();
        let payload = serde_json::json!({
            "step": step,
            "bias": bias,
            "weights": weights_vec,
        });
        let json = serde_json::to_string_pretty(&payload).map_err(|e| {
            ModelError::load_error("weight save", format!("serialisation failed: {e}"))
        })?;
        std::fs::write(&path, json)
            .map_err(|e| ModelError::load_error("weight save", format!("write failed: {e}")))?;
        tracing::info!("Saved weights checkpoint to {:?}", path);
        Ok(path)
    }

    /// Loads `(weights, bias)` from a JSON file written by `save_weights`.
    ///
    /// # Errors
    /// Fails if the file cannot be read, is not valid JSON, or the `bias`
    /// or `weights` fields are missing/non-numeric.
    pub fn load_weights(path: &Path) -> ModelResult<(Array1<f32>, f32)> {
        let json = std::fs::read_to_string(path)
            .map_err(|e| ModelError::load_error("weight load", format!("read failed: {e}")))?;
        let value: serde_json::Value = serde_json::from_str(&json).map_err(|e| {
            ModelError::load_error("weight load", format!("JSON parse failed: {e}"))
        })?;
        let bias = value.get("bias").and_then(|v| v.as_f64()).ok_or_else(|| {
            ModelError::load_error("weight load", "missing or invalid 'bias' field")
        })? as f32;
        let weights_arr = value
            .get("weights")
            .and_then(|v| v.as_array())
            .ok_or_else(|| {
                ModelError::load_error("weight load", "missing or invalid 'weights' field")
            })?;
        // Fallible collect: short-circuits on the first non-numeric element.
        let weights: Vec<f32> = weights_arr
            .iter()
            .enumerate()
            .map(|(i, v)| {
                v.as_f64()
                    .ok_or_else(|| {
                        ModelError::load_error(
                            "weight load",
                            format!("weights[{i}] is not a number"),
                        )
                    })
                    .map(|x| x as f32)
            })
            .collect::<ModelResult<Vec<f32>>>()?;
        Ok((Array1::from_vec(weights), bias))
    }

    /// Lists `weights_step_{N}.json` files in the checkpoint directory as
    /// `(step, path)` pairs sorted by step. Returns an empty list when the
    /// directory does not exist; files that do not match the naming scheme
    /// are silently skipped.
    ///
    /// # Errors
    /// Fails if the directory or one of its entries cannot be read.
    pub fn list_weight_checkpoints(&self) -> ModelResult<Vec<(usize, PathBuf)>> {
        if !self.checkpoint_dir.exists() {
            return Ok(Vec::new());
        }
        let read_dir = std::fs::read_dir(&self.checkpoint_dir).map_err(|e| {
            ModelError::load_error("weight list", format!("failed to read checkpoint dir: {e}"))
        })?;
        let mut results: Vec<(usize, PathBuf)> = Vec::new();
        for entry in read_dir {
            let entry = entry.map_err(|e| {
                ModelError::load_error("weight list", format!("directory entry error: {e}"))
            })?;
            let path = entry.path();
            if let Some(name) = path.file_name().and_then(|n| n.to_str()) {
                if let Some(rest) = name.strip_prefix("weights_step_") {
                    if let Some(step_str) = rest.strip_suffix(".json") {
                        if let Ok(step) = step_str.parse::<usize>() {
                            results.push((step, path));
                        }
                    }
                }
            }
        }
        results.sort_by_key(|(step, _)| *step);
        Ok(results)
    }

    /// Deletes the oldest tracked checkpoint files until at most
    /// `max_checkpoints` remain in `recent_checkpoints`. Paths that no
    /// longer exist on disk are dropped without error.
    ///
    /// # Errors
    /// Fails if an existing checkpoint file cannot be removed.
    pub fn cleanup_old_checkpoints(&mut self) -> ModelResult<()> {
        while self.recent_checkpoints.len() > self.max_checkpoints {
            if let Some(old_checkpoint) = self.recent_checkpoints.pop_front() {
                if old_checkpoint.exists() {
                    std::fs::remove_file(&old_checkpoint).map_err(|e| {
                        ModelError::load_error(
                            "checkpoint cleanup",
                            format!("Failed to remove old checkpoint: {}", e),
                        )
                    })?;
                    tracing::debug!("Removed old checkpoint: {:?}", old_checkpoint);
                }
            }
        }
        Ok(())
    }
}
/// Stops training after `patience` consecutive epochs without the
/// validation loss improving by more than `min_delta` (lower is better).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EarlyStopping {
    /// Consecutive non-improving epochs tolerated before stopping.
    pub patience: usize,
    /// Minimum decrease in loss required to count as an improvement.
    pub min_delta: f32,
    // Consecutive epochs without improvement so far.
    counter: usize,
    // Lowest loss observed so far, if any.
    best_loss: Option<f32>,
    // Latched once the patience budget is exhausted; cleared by reset().
    stopped: bool,
}
impl EarlyStopping {
    /// Builds a monitor that halts after `patience` consecutive epochs
    /// whose loss fails to improve on the best by more than `min_delta`.
    pub fn new(patience: usize, min_delta: f32) -> Self {
        Self {
            patience,
            min_delta,
            counter: 0,
            best_loss: None,
            stopped: false,
        }
    }

    /// Feeds one validation loss and reports whether training should stop.
    /// Once stopping has triggered, every subsequent call returns `true`
    /// until `reset` is called.
    pub fn should_stop(&mut self, current_loss: f32) -> bool {
        if self.stopped {
            return true;
        }
        match self.best_loss {
            // First observation just establishes the baseline.
            None => {
                self.best_loss = Some(current_loss);
                tracing::info!("Initial validation loss: {:.6}", current_loss);
                false
            }
            // Genuine improvement: record it and reset patience.
            Some(best) if current_loss < best - self.min_delta => {
                self.best_loss = Some(current_loss);
                self.counter = 0;
                tracing::info!(
                    "Validation improved: {:.6} -> {:.6}",
                    best,
                    current_loss
                );
                false
            }
            // No (sufficient) improvement: burn one unit of patience.
            Some(best) => {
                self.counter += 1;
                tracing::debug!(
                    "No improvement for {} epochs (best: {:.6}, current: {:.6})",
                    self.counter,
                    best,
                    current_loss
                );
                if self.counter >= self.patience {
                    self.stopped = true;
                    tracing::info!(
                        "Early stopping triggered after {} epochs without improvement",
                        self.patience
                    );
                }
                self.stopped
            }
        }
    }

    /// Clears all internal state so the monitor can be reused.
    pub fn reset(&mut self) {
        self.counter = 0;
        self.best_loss = None;
        self.stopped = false;
    }

    /// Whether early stopping has already triggered.
    pub fn is_stopped(&self) -> bool {
        self.stopped
    }

    /// Lowest loss observed since construction or the last `reset`.
    pub fn best_loss(&self) -> Option<f32> {
        self.best_loss
    }
}
/// Accumulates per-metric value histories and derives simple statistics
/// (latest, best, moving average over a trailing window).
#[derive(Debug, Clone)]
pub struct ValidationMetrics {
    // Metric name -> values in recording order.
    history: std::collections::HashMap<String, Vec<f32>>,
    // Number of trailing samples used for the moving average.
    window_size: usize,
}
impl ValidationMetrics {
    /// Creates an empty collector with a moving-average window of 10.
    pub fn new() -> Self {
        Self {
            history: std::collections::HashMap::new(),
            window_size: 10,
        }
    }

    /// Builder-style setter for the moving-average window size.
    pub fn window_size(mut self, size: usize) -> Self {
        self.window_size = size;
        self
    }

    /// Appends `value` to the history of metric `name`, creating the
    /// history if this is the first sample.
    pub fn record(&mut self, name: &str, value: f32) {
        self.history
            .entry(name.to_string())
            .or_default()
            .push(value);
    }

    /// Full recorded history for `name`, if any values were recorded.
    pub fn get_history(&self, name: &str) -> Option<&Vec<f32>> {
        self.history.get(name)
    }

    /// Most recently recorded value for `name`.
    pub fn get_latest(&self, name: &str) -> Option<f32> {
        self.history.get(name).and_then(|h| h.last().copied())
    }

    /// Mean of the last `window_size` values recorded for `name`.
    ///
    /// Returns `None` for an unknown metric or an empty window — fixes a
    /// `Some(NaN)` result (0.0 / 0) when `window_size` is 0 on a
    /// non-empty history.
    pub fn get_moving_avg(&self, name: &str) -> Option<f32> {
        let history = self.history.get(name)?;
        let start = history.len().saturating_sub(self.window_size);
        let window = &history[start..];
        if window.is_empty() {
            return None;
        }
        let sum: f32 = window.iter().sum();
        Some(sum / window.len() as f32)
    }

    /// Whether the latest value improved on the previous one by more than
    /// `min_delta` (lower is better). Vacuously `true` with fewer than
    /// two samples; `false` for an unknown metric.
    pub fn is_improving(&self, name: &str, min_delta: f32) -> bool {
        match self.history.get(name) {
            Some(history) if history.len() >= 2 => {
                let current = history[history.len() - 1];
                let previous = history[history.len() - 2];
                current < previous - min_delta
            }
            Some(_) => true,
            None => false,
        }
    }

    /// Minimum value ever recorded for `name` (lower is better).
    pub fn get_best(&self, name: &str) -> Option<f32> {
        // reduce(f32::min) is the idiomatic equivalent of the manual
        // fold-to-Option minimum; still yields None on an empty history.
        self.history
            .get(name)
            .and_then(|h| h.iter().copied().reduce(f32::min))
    }

    /// Logs latest/best/moving-average for every tracked metric.
    pub fn print_summary(&self) {
        tracing::info!("=== Validation Metrics Summary ===");
        for (name, history) in &self.history {
            if let (Some(latest), Some(best)) = (history.last(), self.get_best(name)) {
                let avg = self.get_moving_avg(name).unwrap_or(0.0);
                tracing::info!(
                    "{}: latest={:.6}, best={:.6}, avg={:.6}",
                    name,
                    latest,
                    best,
                    avg
                );
            }
        }
    }
}
impl Default for ValidationMetrics {
    /// Same as [`ValidationMetrics::new`]: empty history, window of 10.
    fn default() -> Self {
        Self::new()
    }
}
/// Configuration for gradient checkpointing.
// NOTE(review): nothing in this file consumes these settings — presumably a
// trainer elsewhere reads them; confirm the exact semantics there.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GradientCheckpointConfig {
    /// Whether gradient checkpointing is enabled.
    pub enabled: bool,
    /// Number of checkpoint segments (interpretation is up to the consumer).
    pub num_segments: usize,
    /// Whether checkpointed state is offloaded to CPU memory.
    pub cpu_offload: bool,
}
impl Default for GradientCheckpointConfig {
    /// Checkpointing disabled by default, with 4 segments and no CPU offload.
    fn default() -> Self {
        Self {
            enabled: false,
            num_segments: 4,
            cpu_offload: false,
        }
    }
}
impl GradientCheckpointConfig {
    /// Equivalent to [`GradientCheckpointConfig::default`].
    pub fn new() -> Self {
        Self::default()
    }

    /// Turns gradient checkpointing on (builder style).
    pub fn enable(self) -> Self {
        Self {
            enabled: true,
            ..self
        }
    }

    /// Sets the number of checkpoint segments (builder style).
    pub fn segments(self, n: usize) -> Self {
        Self {
            num_segments: n,
            ..self
        }
    }

    /// Enables or disables CPU offload (builder style).
    pub fn cpu_offload(self, enable: bool) -> Self {
        Self {
            cpu_offload: enable,
            ..self
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_checkpoint_metadata() {
        let meta = CheckpointMetadata::new(5, 1000)
            .with_val_loss(0.5)
            .with_train_loss(0.6)
            .with_metric("accuracy".to_string(), 0.95);
        assert_eq!(meta.epoch, 5);
        assert_eq!(meta.step, 1000);
        assert_eq!(meta.val_loss, Some(0.5));
        assert_eq!(meta.train_loss, Some(0.6));
        assert_eq!(meta.metrics.get("accuracy"), Some(&0.95));
    }

    #[test]
    fn test_early_stopping() {
        let mut early_stop = EarlyStopping::new(3, 0.01);
        // Two real improvements, then three stagnant epochs exhaust patience.
        assert!(!early_stop.should_stop(1.0));
        assert!(!early_stop.should_stop(0.9));
        assert!(!early_stop.should_stop(0.85));
        assert!(!early_stop.should_stop(0.851));
        assert!(!early_stop.should_stop(0.852));
        assert!(early_stop.should_stop(0.853));
    }

    #[test]
    fn test_early_stopping_reset() {
        let mut early_stop = EarlyStopping::new(2, 0.01);
        // Three identical losses: baseline, then two strikes -> stopped.
        for _ in 0..3 {
            early_stop.should_stop(1.0);
        }
        assert!(early_stop.is_stopped());
        early_stop.reset();
        assert!(!early_stop.is_stopped());
        assert_eq!(early_stop.best_loss(), None);
    }

    #[test]
    fn test_validation_metrics() {
        let mut metrics = ValidationMetrics::new().window_size(3);
        for &v in &[1.0, 0.9, 0.8, 0.7] {
            metrics.record("loss", v);
        }
        assert_eq!(metrics.get_latest("loss"), Some(0.7));
        assert_eq!(metrics.get_best("loss"), Some(0.7));
        let avg = metrics
            .get_moving_avg("loss")
            .expect("Failed to get moving avg");
        // Window of 3 -> mean of (0.9, 0.8, 0.7).
        assert!((avg - 0.8).abs() < 1e-6);
    }

    #[test]
    fn test_validation_metrics_improving() {
        let mut metrics = ValidationMetrics::new();
        metrics.record("loss", 1.0);
        metrics.record("loss", 0.8);
        assert!(metrics.is_improving("loss", 0.1));
        metrics.record("loss", 0.79);
        // 0.79 is not 0.1 below 0.8 -> not improving.
        assert!(!metrics.is_improving("loss", 0.1));
    }

    #[test]
    fn test_checkpoint_manager_is_best() {
        let mut manager = CheckpointManager::new("/tmp/checkpoints");
        assert!(manager.is_best(1.0));
        let meta = CheckpointMetadata::new(1, 100).with_val_loss(1.0);
        manager.update_best(meta);
        assert!(!manager.is_best(1.1));
        assert!(manager.is_best(0.9));
    }

    #[test]
    fn test_gradient_checkpoint_config() {
        let config = GradientCheckpointConfig::new()
            .enable()
            .segments(8)
            .cpu_offload(true);
        assert!(config.enabled);
        assert_eq!(config.num_segments, 8);
        assert!(config.cpu_offload);
    }
}