use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use serde_json::Value as JsonValue;
use std::collections::HashMap;
use std::path::PathBuf;
use std::sync::Arc;
use thiserror::Error;
use tokio::fs;
use tokio::io::AsyncWriteExt;
use tokio::sync::RwLock;
use tracing::{debug, info};
/// Errors returned by the training-management operations in this file.
#[derive(Debug, Error)]
pub enum TrainingError {
/// No experiment metadata exists for the given identifier.
///
/// NOTE(review): `add_trial` and `get_search` also return this variant when a
/// *search* identifier is unknown.
#[error("Experiment not found: {0}")]
ExperimentNotFound(String),
/// No checkpoint directory exists for the given checkpoint identifier.
#[error("Checkpoint not found: {0}")]
CheckpointNotFound(String),
/// A checkpoint shard was invalid. Not constructed anywhere in the visible source.
#[error("Invalid checkpoint shard: {0}")]
InvalidShard(String),
/// An experiment with the same name is already present in the in-memory cache.
#[error("Experiment already exists: {0}")]
ExperimentAlreadyExists(String),
/// A configuration value was rejected. Not constructed anywhere in the visible source.
#[error("Invalid configuration: {0}")]
InvalidConfig(String),
/// A filesystem operation failed; converted automatically from `std::io::Error` via `?`.
#[error("I/O error: {0}")]
IoError(#[from] std::io::Error),
/// JSON encoding or decoding failed; converted automatically from `serde_json::Error` via `?`.
#[error("Serialization error: {0}")]
SerializationError(#[from] serde_json::Error),
}
/// Convenience alias for results whose error type is [`TrainingError`].
pub type TrainingResult<T> = Result<T, TrainingError>;
/// Caller-supplied description of a new experiment, consumed by
/// `TrainingManager::create_experiment`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExperimentConfig {
/// Experiment name; creation fails if a cached experiment already uses it.
pub name: String,
/// Optional free-form description, copied onto the experiment.
pub description: Option<String>,
/// Tags copied onto the experiment.
pub tags: Vec<String>,
/// Arbitrary JSON value copied onto the experiment without interpretation.
pub hyperparameters: JsonValue,
}
/// Persisted experiment record, stored as
/// `experiments/<id>/metadata.json` under the manager's base path.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Experiment {
/// UUID v4 string generated when the experiment is created.
pub id: String,
/// Name taken from the creating `ExperimentConfig`.
pub name: String,
/// Optional description taken from the creating `ExperimentConfig`.
pub description: Option<String>,
/// Tags taken from the creating `ExperimentConfig`.
pub tags: Vec<String>,
/// Hyperparameters taken from the creating `ExperimentConfig`.
pub hyperparameters: JsonValue,
/// Time the experiment was created.
pub created_at: DateTime<Utc>,
/// Refreshed on status updates and on every saved checkpoint.
pub updated_at: DateTime<Utc>,
/// Current lifecycle status; starts as `ExperimentStatus::Running`.
pub status: ExperimentStatus,
/// Metrics of the first saved checkpoint; the visible code never replaces it
/// once it has been set.
pub best_metrics: Option<JsonValue>,
/// Incremented once per saved checkpoint; the visible retention code does not
/// decrement it when checkpoints are removed.
pub checkpoint_count: u64,
}
/// Lifecycle status stored on an [`Experiment`].
///
/// Only `Running` is assigned by the visible code (at creation); the other
/// variants are set by callers through `update_experiment_status`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ExperimentStatus {
/// Initial status of a newly created experiment.
Running,
/// Presumably: the run finished normally (set by callers).
Completed,
/// Presumably: the run ended with an error (set by callers).
Failed,
/// Presumably: the run was halted before finishing (set by callers).
Stopped,
}
/// Settings that control how checkpoints are stored and pruned.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CheckpointConfig {
/// Maximum shard size (presumably in bytes — TODO confirm). Not read by the
/// visible code.
pub max_shard_size: u64,
/// Whether checkpoint data should be compressed. Not read by the visible code.
pub compression_enabled: bool,
/// Number of checkpoints retained per experiment; `0` disables pruning.
pub max_checkpoints: u32,
/// Policy used to choose which checkpoints are pruned.
pub retention_policy: RetentionPolicy,
}
impl Default for CheckpointConfig {
    /// Returns the default settings: a `1024 * 1024 * 1024` shard size,
    /// compression enabled, at most five checkpoints, and
    /// `RetentionPolicy::KeepBest`.
    fn default() -> Self {
        const DEFAULT_MAX_SHARD_SIZE: u64 = 1024 * 1024 * 1024;
        const DEFAULT_MAX_CHECKPOINTS: u32 = 5;

        Self {
            max_shard_size: DEFAULT_MAX_SHARD_SIZE,
            compression_enabled: true,
            max_checkpoints: DEFAULT_MAX_CHECKPOINTS,
            retention_policy: RetentionPolicy::KeepBest,
        }
    }
}
/// Strategy for pruning checkpoints once an experiment exceeds
/// `CheckpointConfig::max_checkpoints`.
///
/// NOTE(review): in the visible `apply_retention_policy`, every variant except
/// `KeepAll` behaves identically — the highest-epoch checkpoints are kept.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum RetentionPolicy {
/// Never remove checkpoints.
KeepAll,
/// Keep the checkpoints with the highest epochs.
KeepRecent,
/// Presumably meant to keep the best-scoring checkpoints; currently handled
/// the same way as `KeepRecent`.
KeepBest,
/// Presumably meant to keep checkpoints at a fixed epoch interval; currently
/// handled the same way as `KeepRecent`.
KeepEpochInterval,
}
/// Metadata describing one saved checkpoint, stored as
/// `checkpoints/<experiment_id>/<id>/metadata.json` under the manager's base path.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Checkpoint {
/// UUID v4 string generated when the checkpoint is saved.
pub id: String,
/// Identifier of the experiment that owns this checkpoint.
pub experiment_id: String,
/// Epoch supplied by the caller of `save_checkpoint`.
pub epoch: u64,
/// Training step; the visible `save_checkpoint` always writes `0`.
pub step: u64,
/// Time the checkpoint record was built.
pub created_at: DateTime<Utc>,
/// Metrics supplied by the caller of `save_checkpoint`.
pub metrics: JsonValue,
/// Number of shards; the visible `save_checkpoint` always writes `1`.
pub shard_count: u32,
/// Combined length in bytes of the model state and, if present, the optimizer state.
pub total_size: u64,
/// Whether an `optimizer.bin` file was written alongside `model.bin`.
pub has_optimizer_state: bool,
/// Whether the checkpoint is sharded; the visible `save_checkpoint` always writes `false`.
pub is_sharded: bool,
}
/// One line of an experiment's `metrics.jsonl` log.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MetricsEntry {
/// Identifier of the experiment the metrics belong to.
pub experiment_id: String,
/// Step supplied by the caller of `log_metrics`.
pub step: u64,
/// Time the entry was created by `log_metrics`.
pub timestamp: DateTime<Utc>,
/// Arbitrary JSON metrics payload, stored without interpretation.
pub metrics: JsonValue,
}
/// Persisted state of a hyperparameter search, stored as
/// `searches/<id>.json` under the manager's base path.
///
/// NOTE(review): the visible code only ever appends to `trials`; the
/// `experiment_ids`, `best_params`, `best_metric_value` and `completed_at`
/// fields keep their initial values.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HyperparameterSearchResult {
/// UUID v4 string generated when the search is created.
pub id: String,
/// Identifiers of associated experiments; starts empty.
pub experiment_ids: Vec<String>,
/// Search space supplied by the caller, stored without interpretation.
pub search_space: JsonValue,
/// Best parameters found so far; starts as JSON `null`.
pub best_params: JsonValue,
/// Best value of the optimization metric; starts as `None`.
pub best_metric_value: Option<f64>,
/// Name of the metric being optimized, supplied by the caller.
pub optimization_metric: String,
/// Trials recorded through `add_trial`, in insertion order.
pub trials: Vec<Trial>,
/// Time the search was created.
pub started_at: DateTime<Utc>,
/// Completion time; starts as `None`.
pub completed_at: Option<DateTime<Utc>>,
}
/// A single evaluated parameter set within a hyperparameter search.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Trial {
/// UUID v4 string generated when the trial is added.
pub id: String,
/// Parameters used for the trial, supplied by the caller.
pub params: JsonValue,
/// Metrics reported for the trial, supplied by the caller.
pub metrics: JsonValue,
/// Status of the trial, supplied by the caller.
pub status: TrialStatus,
}
/// Status recorded on a [`Trial`]; always supplied by the caller of `add_trial`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum TrialStatus {
/// Presumably: the trial is still in progress.
Running,
/// Presumably: the trial finished normally.
Completed,
/// Presumably: the trial ended with an error.
Failed,
}
/// Filesystem-backed manager for experiments, checkpoints, metrics logs and
/// hyperparameter searches, with in-memory caches in front of the JSON files.
pub struct TrainingManager {
// Root directory; the `experiments`, `checkpoints`, `metrics` and `searches`
// subdirectories are created beneath it.
base_path: PathBuf,
// Checkpoint settings; `max_checkpoints` and `retention_policy` drive pruning.
checkpoint_config: CheckpointConfig,
// Cache of experiments keyed by experiment id.
experiments: Arc<RwLock<HashMap<String, Experiment>>>,
// Cache of hyperparameter searches keyed by search id.
searches: Arc<RwLock<HashMap<String, HyperparameterSearchResult>>>,
}
impl TrainingManager {
/// Creates a manager rooted at `base_path` with the default checkpoint
/// configuration and empty caches.
///
/// Nothing is created on disk here; directories are made when the first
/// experiment or search is created.
pub fn new(base_path: PathBuf) -> Self {
    let experiments = Arc::new(RwLock::new(HashMap::new()));
    let searches = Arc::new(RwLock::new(HashMap::new()));
    let checkpoint_config = CheckpointConfig::default();

    Self {
        base_path,
        checkpoint_config,
        experiments,
        searches,
    }
}
pub fn with_checkpoint_config(mut self, config: CheckpointConfig) -> Self {
self.checkpoint_config = config;
self
}
/// Creates the base directory and its `experiments`, `checkpoints`,
/// `metrics` and `searches` subdirectories if they do not exist yet.
///
/// # Errors
/// Returns `TrainingError::IoError` if any directory cannot be created.
async fn ensure_directories(&self) -> TrainingResult<()> {
    fs::create_dir_all(&self.base_path).await?;
    for subdir in ["experiments", "checkpoints", "metrics", "searches"] {
        fs::create_dir_all(self.base_path.join(subdir)).await?;
    }
    Ok(())
}
/// Creates a new experiment from `config`, persists its metadata to
/// `experiments/<id>/metadata.json`, and caches it.
///
/// The name-uniqueness check and the cache insertion happen under a single
/// write lock, so two concurrent calls with the same name cannot both succeed.
/// The check covers only experiments currently in the in-memory cache.
///
/// # Errors
/// * `TrainingError::ExperimentAlreadyExists` if a cached experiment has the same name.
/// * `TrainingError::IoError` if a directory or the metadata file cannot be written.
/// * `TrainingError::SerializationError` if the experiment cannot be encoded as JSON.
pub async fn create_experiment(&self, config: ExperimentConfig) -> TrainingResult<Experiment> {
    self.ensure_directories().await?;

    // Taken before the name check and held until the insert: releasing it in
    // between would let a second caller pass the same check.
    let mut experiments = self.experiments.write().await;
    if experiments.values().any(|e| e.name == config.name) {
        return Err(TrainingError::ExperimentAlreadyExists(config.name));
    }

    let now = Utc::now();
    // `config` is owned, so its fields are moved rather than cloned.
    let experiment = Experiment {
        id: uuid::Uuid::new_v4().to_string(),
        name: config.name,
        description: config.description,
        tags: config.tags,
        hyperparameters: config.hyperparameters,
        created_at: now,
        updated_at: now,
        status: ExperimentStatus::Running,
        best_metrics: None,
        checkpoint_count: 0,
    };

    let exp_path = self.base_path.join("experiments").join(&experiment.id);
    fs::create_dir_all(&exp_path).await?;
    let json = serde_json::to_string_pretty(&experiment)?;
    fs::write(exp_path.join("metadata.json"), json).await?;

    // Cache only after the metadata has reached disk.
    experiments.insert(experiment.id.clone(), experiment.clone());
    drop(experiments);

    info!(
        "Created experiment: {} ({})",
        experiment.name, experiment.id
    );
    Ok(experiment)
}
/// Returns the experiment with the given id, consulting the in-memory cache
/// first and falling back to `experiments/<id>/metadata.json` on disk.
///
/// An experiment loaded from disk is inserted into the cache.
///
/// # Errors
/// * `TrainingError::ExperimentNotFound` if the metadata file does not exist.
/// * `TrainingError::IoError` for any other failure reading the file.
/// * `TrainingError::SerializationError` if the file is not valid experiment JSON.
pub async fn get_experiment(&self, experiment_id: &str) -> TrainingResult<Experiment> {
    {
        // Scoped so the read guard is released before the write lock below.
        let experiments = self.experiments.read().await;
        if let Some(exp) = experiments.get(experiment_id) {
            return Ok(exp.clone());
        }
    }

    let exp_file = self
        .base_path
        .join("experiments")
        .join(experiment_id)
        .join("metadata.json");

    // Read directly and translate "not found": this avoids the blocking
    // `Path::exists` call and the gap between checking and reading.
    let json = match fs::read_to_string(&exp_file).await {
        Ok(json) => json,
        Err(err) if err.kind() == std::io::ErrorKind::NotFound => {
            return Err(TrainingError::ExperimentNotFound(experiment_id.to_string()));
        }
        Err(err) => return Err(err.into()),
    };
    let experiment: Experiment = serde_json::from_str(&json)?;

    self.experiments
        .write()
        .await
        .insert(experiment_id.to_string(), experiment.clone());
    Ok(experiment)
}
pub async fn update_experiment_status(
&self,
experiment_id: &str,
status: ExperimentStatus,
) -> TrainingResult<()> {
let mut exp = self.get_experiment(experiment_id).await?;
exp.status = status;
exp.updated_at = Utc::now();
let exp_file = self
.base_path
.join("experiments")
.join(experiment_id)
.join("metadata.json");
let json = serde_json::to_string_pretty(&exp)?;
fs::write(&exp_file, json).await?;
let mut experiments = self.experiments.write().await;
experiments.insert(experiment_id.to_string(), exp);
Ok(())
}
/// Saves a checkpoint for an experiment and then applies the retention policy.
///
/// Writes `model.bin`, optionally `optimizer.bin`, and `metadata.json` into
/// `checkpoints/<experiment_id>/<checkpoint_id>/`, then increments the
/// experiment's `checkpoint_count`, refreshes `updated_at`, and records
/// `metrics` as `best_metrics` if none has been recorded yet.
///
/// The experiment update is a read-modify-write performed under the cache
/// write lock, so concurrent saves do not lose counter increments. The lock is
/// released before retention runs, so pruning does not block cache readers.
///
/// The checkpoint's `step` is always `0`, `shard_count` is `1` and
/// `is_sharded` is `false`.
///
/// # Errors
/// * Any error from `get_experiment` (e.g. `ExperimentNotFound`).
/// * `TrainingError::IoError` / `TrainingError::SerializationError` if a file
///   cannot be written or a record cannot be encoded.
/// * Any error from applying the retention policy; the checkpoint itself has
///   already been written in that case.
pub async fn save_checkpoint(
    &self,
    experiment_id: &str,
    epoch: u64,
    model_state: Vec<u8>,
    optimizer_state: Option<Vec<u8>>,
    metrics: JsonValue,
) -> TrainingResult<Checkpoint> {
    // Fails fast for unknown experiments and warms the cache for the update below.
    let loaded = self.get_experiment(experiment_id).await?;

    let optimizer_len = optimizer_state.as_ref().map_or(0, Vec::len);
    let checkpoint = Checkpoint {
        id: uuid::Uuid::new_v4().to_string(),
        experiment_id: experiment_id.to_string(),
        epoch,
        step: 0,
        created_at: Utc::now(),
        metrics,
        shard_count: 1,
        total_size: (model_state.len() + optimizer_len) as u64,
        has_optimizer_state: optimizer_state.is_some(),
        is_sharded: false,
    };

    let ckpt_dir = self
        .base_path
        .join("checkpoints")
        .join(experiment_id)
        .join(&checkpoint.id);
    fs::create_dir_all(&ckpt_dir).await?;
    fs::write(ckpt_dir.join("model.bin"), &model_state).await?;
    if let Some(opt_state) = &optimizer_state {
        fs::write(ckpt_dir.join("optimizer.bin"), opt_state).await?;
    }
    // Metadata goes last so a directory with metadata always has its payload.
    let json = serde_json::to_string_pretty(&checkpoint)?;
    fs::write(ckpt_dir.join("metadata.json"), json).await?;

    {
        let mut experiments = self.experiments.write().await;
        // Start from the freshest cached copy rather than the snapshot taken
        // before the payload was written.
        let mut exp = experiments
            .get(experiment_id)
            .cloned()
            .unwrap_or(loaded);
        exp.checkpoint_count += 1;
        exp.updated_at = Utc::now();
        if exp.best_metrics.is_none() {
            exp.best_metrics = Some(checkpoint.metrics.clone());
        }

        let exp_file = self
            .base_path
            .join("experiments")
            .join(experiment_id)
            .join("metadata.json");
        let json = serde_json::to_string_pretty(&exp)?;
        fs::write(&exp_file, json).await?;
        // Cache only after the metadata has reached disk.
        experiments.insert(experiment_id.to_string(), exp);
    }

    info!(
        "Saved checkpoint {} for experiment {} at epoch {}",
        checkpoint.id, experiment_id, epoch
    );
    self.apply_retention_policy(experiment_id).await?;
    Ok(checkpoint)
}
/// Loads a checkpoint's metadata, model state and (if recorded) optimizer
/// state by checkpoint id.
///
/// The owning experiment is not known up front, so every experiment directory
/// under `checkpoints/` is probed for a subdirectory named `checkpoint_id`;
/// the first match is used.
///
/// # Errors
/// * `TrainingError::CheckpointNotFound` if the `checkpoints` directory is
///   missing or no experiment directory contains `checkpoint_id`.
/// * `TrainingError::IoError` if a directory or file cannot be read.
/// * `TrainingError::SerializationError` if `metadata.json` is not valid checkpoint JSON.
pub async fn load_checkpoint(&self, checkpoint_id: &str) -> TrainingResult<LoadedCheckpoint> {
    let checkpoints_dir = self.base_path.join("checkpoints");
    let mut entries = match fs::read_dir(&checkpoints_dir).await {
        Ok(entries) => entries,
        Err(err) if err.kind() == std::io::ErrorKind::NotFound => {
            return Err(TrainingError::CheckpointNotFound(checkpoint_id.to_string()));
        }
        Err(err) => return Err(err.into()),
    };

    let mut located = None;
    while let Some(entry) = entries.next_entry().await? {
        if !entry.file_type().await?.is_dir() {
            continue;
        }
        let candidate = entry.path().join(checkpoint_id);
        // Async equivalent of `Path::exists`, which would block the executor thread.
        if fs::metadata(&candidate).await.is_ok() {
            // Keep the path that was actually found instead of rebuilding it
            // from a lossily-converted directory name.
            located = Some(candidate);
            break;
        }
    }
    let ckpt_dir = located
        .ok_or_else(|| TrainingError::CheckpointNotFound(checkpoint_id.to_string()))?;

    let json = fs::read_to_string(ckpt_dir.join("metadata.json")).await?;
    let checkpoint: Checkpoint = serde_json::from_str(&json)?;
    let model_state = fs::read(ckpt_dir.join("model.bin")).await?;
    let optimizer_state = if checkpoint.has_optimizer_state {
        Some(fs::read(ckpt_dir.join("optimizer.bin")).await?)
    } else {
        None
    };

    debug!(
        "Loaded checkpoint {} from epoch {}",
        checkpoint_id, checkpoint.epoch
    );
    Ok(LoadedCheckpoint {
        checkpoint,
        model_state,
        optimizer_state,
    })
}
pub async fn list_checkpoints(&self, experiment_id: &str) -> TrainingResult<Vec<Checkpoint>> {
let ckpt_dir = self.base_path.join("checkpoints").join(experiment_id);
if !ckpt_dir.exists() {
return Ok(Vec::new());
}
let mut checkpoints = Vec::new();
let mut entries = fs::read_dir(&ckpt_dir).await?;
while let Some(entry) = entries.next_entry().await? {
if entry.file_type().await?.is_dir() {
let meta_path = entry.path().join("metadata.json");
if meta_path.exists() {
let json = fs::read_to_string(&meta_path).await?;
if let Ok(ckpt) = serde_json::from_str::<Checkpoint>(&json) {
checkpoints.push(ckpt);
}
}
}
}
checkpoints.sort_by_key(|b| std::cmp::Reverse(b.epoch));
Ok(checkpoints)
}
/// Appends one metrics record for `experiment_id` at `step` to
/// `metrics/<experiment_id>/metrics.jsonl` (one JSON object per line).
///
/// # Errors
/// * Any error from `get_experiment` if the experiment is unknown.
/// * `TrainingError::IoError` if the directory or file cannot be created or written.
/// * `TrainingError::SerializationError` if the record cannot be encoded.
pub async fn log_metrics(
    &self,
    experiment_id: &str,
    step: u64,
    metrics: JsonValue,
) -> TrainingResult<()> {
    // Only used to reject unknown experiments; the record itself is not needed.
    self.get_experiment(experiment_id).await?;

    let record = MetricsEntry {
        experiment_id: experiment_id.to_string(),
        step,
        timestamp: Utc::now(),
        metrics,
    };

    let target_dir = self.base_path.join("metrics").join(experiment_id);
    fs::create_dir_all(&target_dir).await?;
    let log_path = target_dir.join("metrics.jsonl");

    let mut line = serde_json::to_string(&record)?;
    line.push_str("\n");

    let mut sink = fs::OpenOptions::new()
        .create(true)
        .append(true)
        .open(&log_path)
        .await?;
    sink.write_all(line.as_bytes()).await?;
    sink.flush().await?;

    debug!(
        "Logged metrics for experiment {} at step {}",
        experiment_id, step
    );
    Ok(())
}
/// Reads every metrics record logged for an experiment, in file order.
///
/// Returns an empty list if no metrics file exists. Lines that do not parse
/// as a `MetricsEntry` are silently skipped.
///
/// # Errors
/// Returns `TrainingError::IoError` if the metrics file exists but cannot be read.
pub async fn get_metrics(&self, experiment_id: &str) -> TrainingResult<Vec<MetricsEntry>> {
    let log_path = self
        .base_path
        .join("metrics")
        .join(experiment_id)
        .join("metrics.jsonl");
    if !log_path.exists() {
        return Ok(Vec::new());
    }

    let raw = fs::read_to_string(&log_path).await?;
    let parsed = raw
        .lines()
        .filter_map(|line| serde_json::from_str::<MetricsEntry>(line).ok())
        .collect();
    Ok(parsed)
}
/// Prunes an experiment's checkpoints down to `max_checkpoints`.
///
/// Does nothing when `max_checkpoints` is `0` or the policy is
/// `RetentionPolicy::KeepAll`; in those cases the checkpoint directory is not
/// even listed. Otherwise the checkpoints are taken in the order returned by
/// `list_checkpoints` and everything after the first `max_checkpoints`
/// entries is deleted from disk.
///
/// NOTE(review): `KeepRecent`, `KeepBest` and `KeepEpochInterval` all use this
/// same rule; policy-specific selection is not implemented here.
///
/// # Errors
/// Returns any error from `list_checkpoints`, or `TrainingError::IoError` if a
/// checkpoint directory cannot be removed for a reason other than already
/// being gone.
async fn apply_retention_policy(&self, experiment_id: &str) -> TrainingResult<()> {
    let limit = self.checkpoint_config.max_checkpoints;
    // Exhaustive on purpose: a new policy variant must be classified here.
    let prunes = match self.checkpoint_config.retention_policy {
        RetentionPolicy::KeepAll => false,
        RetentionPolicy::KeepRecent
        | RetentionPolicy::KeepBest
        | RetentionPolicy::KeepEpochInterval => true,
    };
    if limit == 0 || !prunes {
        return Ok(());
    }

    // Lossless on 32/64-bit targets; saturate rather than truncate elsewhere.
    let keep = usize::try_from(limit).unwrap_or(usize::MAX);
    let checkpoints = self.list_checkpoints(experiment_id).await?;

    for ckpt in checkpoints.iter().skip(keep) {
        let ckpt_dir = self
            .base_path
            .join("checkpoints")
            .join(experiment_id)
            .join(&ckpt.id);
        match fs::remove_dir_all(&ckpt_dir).await {
            Ok(()) => info!(
                "Removed old checkpoint {} for experiment {}",
                ckpt.id, experiment_id
            ),
            // Already removed (e.g. by a concurrent prune): nothing left to do.
            Err(err) if err.kind() == std::io::ErrorKind::NotFound => {}
            Err(err) => return Err(err.into()),
        }
    }
    Ok(())
}
/// Creates an empty hyperparameter search, persists it to
/// `searches/<id>.json`, and caches it.
///
/// The new search has no trials or experiment ids, `best_params` is JSON
/// `null`, and `best_metric_value` / `completed_at` are `None`.
///
/// # Errors
/// Returns `TrainingError::IoError` if directories or the search file cannot
/// be written, or `TrainingError::SerializationError` if encoding fails.
pub async fn create_search(
    &self,
    search_space: JsonValue,
    optimization_metric: String,
) -> TrainingResult<HyperparameterSearchResult> {
    self.ensure_directories().await?;

    let record = HyperparameterSearchResult {
        id: uuid::Uuid::new_v4().to_string(),
        experiment_ids: Vec::new(),
        search_space,
        best_params: JsonValue::Null,
        best_metric_value: None,
        optimization_metric,
        trials: Vec::new(),
        started_at: Utc::now(),
        completed_at: None,
    };

    let file_name = format!("{}.json", record.id);
    let destination = self.base_path.join("searches").join(file_name);
    let serialized = serde_json::to_string_pretty(&record)?;
    fs::write(&destination, serialized).await?;

    self.searches
        .write()
        .await
        .insert(record.id.clone(), record.clone());
    info!("Created hyperparameter search: {}", record.id);
    Ok(record)
}
/// Appends a trial to a hyperparameter search, rewrites
/// `searches/<search_id>.json`, and refreshes the cache.
///
/// The whole read-modify-write cycle runs under the search cache's write
/// lock, so concurrent calls on the same manager cannot overwrite each
/// other's trials. Only `trials` is modified; the best-parameter fields are
/// left untouched.
///
/// # Errors
/// * `TrainingError::ExperimentNotFound` (carrying the search id) if the
///   search file does not exist.
/// * `TrainingError::IoError` for other read/write failures.
/// * `TrainingError::SerializationError` if the file cannot be decoded or re-encoded.
pub async fn add_trial(
    &self,
    search_id: &str,
    params: JsonValue,
    metrics: JsonValue,
    status: TrialStatus,
) -> TrainingResult<()> {
    let search_file = self
        .base_path
        .join("searches")
        .join(format!("{}.json", search_id));

    // Taken before the file is read so read-modify-write cycles are serialized.
    let mut searches = self.searches.write().await;

    // Read directly and translate "not found" instead of probing with the
    // blocking `Path::exists` first.
    let json = match fs::read_to_string(&search_file).await {
        Ok(json) => json,
        Err(err) if err.kind() == std::io::ErrorKind::NotFound => {
            return Err(TrainingError::ExperimentNotFound(search_id.to_string()));
        }
        Err(err) => return Err(err.into()),
    };
    let mut search: HyperparameterSearchResult = serde_json::from_str(&json)?;

    search.trials.push(Trial {
        id: uuid::Uuid::new_v4().to_string(),
        params,
        metrics,
        status,
    });

    let json = serde_json::to_string_pretty(&search)?;
    fs::write(&search_file, json).await?;
    searches.insert(search_id.to_string(), search);
    Ok(())
}
/// Returns the hyperparameter search with the given id, consulting the
/// in-memory cache first and falling back to `searches/<search_id>.json`.
///
/// A search loaded from disk is inserted into the cache.
///
/// # Errors
/// * `TrainingError::ExperimentNotFound` (carrying the search id) if the
///   search file does not exist.
/// * `TrainingError::IoError` for any other failure reading the file.
/// * `TrainingError::SerializationError` if the file is not valid search JSON.
pub async fn get_search(&self, search_id: &str) -> TrainingResult<HyperparameterSearchResult> {
    {
        // Scoped so the read guard is released before the write lock below.
        let searches = self.searches.read().await;
        if let Some(search) = searches.get(search_id) {
            return Ok(search.clone());
        }
    }

    let search_file = self
        .base_path
        .join("searches")
        .join(format!("{}.json", search_id));

    // Read directly and translate "not found": this avoids the blocking
    // `Path::exists` call and the gap between checking and reading.
    let json = match fs::read_to_string(&search_file).await {
        Ok(json) => json,
        Err(err) if err.kind() == std::io::ErrorKind::NotFound => {
            return Err(TrainingError::ExperimentNotFound(search_id.to_string()));
        }
        Err(err) => return Err(err.into()),
    };
    let search: HyperparameterSearchResult = serde_json::from_str(&json)?;

    self.searches
        .write()
        .await
        .insert(search_id.to_string(), search.clone());
    Ok(search)
}
}
/// A checkpoint read back from disk by `TrainingManager::load_checkpoint`.
#[derive(Debug)]
pub struct LoadedCheckpoint {
/// Metadata parsed from the checkpoint's `metadata.json`.
pub checkpoint: Checkpoint,
/// Raw contents of the checkpoint's `model.bin`.
pub model_state: Vec<u8>,
/// Raw contents of `optimizer.bin`, present only when the metadata records
/// that optimizer state was saved.
pub optimizer_state: Option<Vec<u8>>,
}
#[cfg(test)]
mod tests {
use super::*;
use tempfile::TempDir;
/// Builds a manager rooted in a fresh temporary directory.
///
/// The `TempDir` is returned so the caller keeps it alive for the duration of
/// the test; dropping it removes the directory.
fn setup_manager() -> (TrainingManager, TempDir) {
    let scratch = TempDir::new().expect("failed to create temp dir");
    let root = scratch.path().to_path_buf();
    (TrainingManager::new(root), scratch)
}
/// A newly created experiment carries the configured name, starts in the
/// `Running` state and has no checkpoints.
#[tokio::test]
async fn test_create_experiment() {
    let (manager, _temp) = setup_manager();

    let created = manager
        .create_experiment(ExperimentConfig {
            name: "test-exp".to_string(),
            description: Some("Test experiment".to_string()),
            tags: vec!["test".to_string()],
            hyperparameters: serde_json::json!({"lr": 0.001}),
        })
        .await
        .expect("create experiment should succeed");

    assert_eq!(created.name, "test-exp");
    assert_eq!(created.status, ExperimentStatus::Running);
    assert_eq!(created.checkpoint_count, 0);
}
#[tokio::test]
async fn test_save_and_load_checkpoint() {
let (manager, _temp) = setup_manager();
let config = ExperimentConfig {
name: "ckpt-test".to_string(),
description: None,
tags: vec![],
hyperparameters: serde_json::json!({}),
};
let exp = manager
.create_experiment(config)
.await
.expect("create experiment should succeed");
let model_state = b"model_data".to_vec();
let optimizer_state = Some(b"optimizer_data".to_vec());
let metrics = serde_json::json!({"loss": 0.5});
let ckpt = manager
.save_checkpoint(
&exp.id,
1,
model_state.clone(),
optimizer_state.clone(),
metrics,
)
.await
.expect("save checkpoint should succeed");
assert_eq!(ckpt.epoch, 1);
assert!(ckpt.has_optimizer_state);
let loaded = manager
.load_checkpoint(&ckpt.id)
.await
.expect("load checkpoint should succeed");
assert_eq!(loaded.model_state, model_state);
assert_eq!(loaded.optimizer_state, optimizer_state);
assert_eq!(loaded.checkpoint.epoch, 1);
}
/// Logged metrics are returned in the order they were written.
#[tokio::test]
async fn test_log_metrics() {
    let (manager, _temp) = setup_manager();
    let experiment = manager
        .create_experiment(ExperimentConfig {
            name: "metrics-test".to_string(),
            description: None,
            tags: vec![],
            hyperparameters: serde_json::json!({}),
        })
        .await
        .expect("create experiment should succeed");

    manager
        .log_metrics(&experiment.id, 1, serde_json::json!({"loss": 0.5}))
        .await
        .expect("log metrics step 1 should succeed");
    manager
        .log_metrics(&experiment.id, 2, serde_json::json!({"loss": 0.4}))
        .await
        .expect("log metrics step 2 should succeed");

    let logged = manager
        .get_metrics(&experiment.id)
        .await
        .expect("get metrics should succeed");
    let steps: Vec<u64> = logged.iter().map(|entry| entry.step).collect();
    assert_eq!(steps, vec![1, 2]);
}
/// Checkpoints are listed from the highest epoch to the lowest.
#[tokio::test]
async fn test_list_checkpoints() {
    let (manager, _temp) = setup_manager();
    let experiment = manager
        .create_experiment(ExperimentConfig {
            name: "list-test".to_string(),
            description: None,
            tags: vec![],
            hyperparameters: serde_json::json!({}),
        })
        .await
        .expect("create experiment should succeed");

    for epoch in 1..=3 {
        let payload = b"model".to_vec();
        let metrics = serde_json::json!({"epoch": epoch});
        manager
            .save_checkpoint(&experiment.id, epoch, payload, None, metrics)
            .await
            .expect("save checkpoint should succeed");
    }

    let listed = manager
        .list_checkpoints(&experiment.id)
        .await
        .expect("list checkpoints should succeed");
    let epochs: Vec<u64> = listed.iter().map(|ckpt| ckpt.epoch).collect();
    assert_eq!(epochs, vec![3, 2, 1]);
}
#[tokio::test]
async fn test_retention_policy() {
let (mut manager, _temp) = setup_manager();
manager.checkpoint_config.max_checkpoints = 2;
manager.checkpoint_config.retention_policy = RetentionPolicy::KeepRecent;
let config = ExperimentConfig {
name: "retention-test".to_string(),
description: None,
tags: vec![],
hyperparameters: serde_json::json!({}),
};
let exp = manager
.create_experiment(config)
.await
.expect("create experiment should succeed");
for epoch in 1..=4 {
manager
.save_checkpoint(
&exp.id,
epoch,
b"model".to_vec(),
None,
serde_json::json!({"epoch": epoch}),
)
.await
.expect("save checkpoint should succeed");
}
let checkpoints = manager
.list_checkpoints(&exp.id)
.await
.expect("list checkpoints should succeed");
assert_eq!(checkpoints.len(), 2);
assert_eq!(checkpoints[0].epoch, 4);
assert_eq!(checkpoints[1].epoch, 3);
}
/// A status update is visible on the next lookup of the experiment.
#[tokio::test]
async fn test_update_experiment_status() {
    let (manager, _temp) = setup_manager();
    let created = manager
        .create_experiment(ExperimentConfig {
            name: "status-test".to_string(),
            description: None,
            tags: vec![],
            hyperparameters: serde_json::json!({}),
        })
        .await
        .expect("create experiment should succeed");
    assert_eq!(created.status, ExperimentStatus::Running);

    manager
        .update_experiment_status(&created.id, ExperimentStatus::Completed)
        .await
        .expect("update status should succeed");

    let refreshed = manager
        .get_experiment(&created.id)
        .await
        .expect("get experiment should succeed");
    assert_eq!(refreshed.status, ExperimentStatus::Completed);
}
#[tokio::test]
async fn test_hyperparameter_search() {
let (manager, _temp) = setup_manager();
let search_space = serde_json::json!({
"learning_rate": [0.001, 0.01, 0.1],
"batch_size": [16, 32, 64],
});
let search = manager
.create_search(search_space.clone(), "accuracy".to_string())
.await
.expect("create search should succeed");
assert_eq!(search.optimization_metric, "accuracy");
assert_eq!(search.trials.len(), 0);
manager
.add_trial(
&search.id,
serde_json::json!({"learning_rate": 0.001, "batch_size": 32}),
serde_json::json!({"accuracy": 0.85}),
TrialStatus::Completed,
)
.await
.expect("add trial should succeed");
}
/// Looking up an unknown experiment id yields `ExperimentNotFound`.
#[tokio::test]
async fn test_experiment_not_found() {
    let (manager, _temp) = setup_manager();

    let lookup = manager.get_experiment("nonexistent").await;
    assert!(lookup.is_err());

    let failure = lookup.expect_err("should fail for nonexistent experiment");
    assert!(matches!(failure, TrainingError::ExperimentNotFound(_)));
}
/// Loading an unknown checkpoint id yields `CheckpointNotFound`.
#[tokio::test]
async fn test_checkpoint_not_found() {
    let (manager, _temp) = setup_manager();

    let attempt = manager.load_checkpoint("nonexistent").await;
    assert!(attempt.is_err());

    let failure = attempt.expect_err("should fail for nonexistent checkpoint");
    assert!(matches!(failure, TrainingError::CheckpointNotFound(_)));
}
/// Creating a second experiment with an already-used name is rejected with
/// `ExperimentAlreadyExists`.
#[tokio::test]
async fn test_duplicate_experiment_name() {
    let (manager, _temp) = setup_manager();
    let make_config = || ExperimentConfig {
        name: "duplicate".to_string(),
        description: None,
        tags: vec![],
        hyperparameters: serde_json::json!({}),
    };

    manager
        .create_experiment(make_config())
        .await
        .expect("first create should succeed");

    let second = manager.create_experiment(make_config()).await;
    assert!(second.is_err());

    let failure = second.expect_err("duplicate experiment should fail");
    assert!(matches!(failure, TrainingError::ExperimentAlreadyExists(_)));
}
}