use std::fs;
use std::io;
use std::path::{Path, PathBuf};
/// Selects a checkpoint directory under `base_dir` according to `criteria`.
///
/// Only immediate subdirectories named `checkpoint-<N>`, where `<N>` parses as an `i64`, are
/// candidates. Supported criteria:
///
/// * `"latest"` — the candidate with the greatest `<N>`.
/// * `"best"` — the candidate whose `trainer_state.json` holds the smallest numeric
///   `metrics.eval_loss`. Candidates whose state file is missing, unreadable, not valid JSON,
///   or lacks that numeric field are skipped.
///
/// Any other `criteria` selects nothing. On a tie the candidate encountered first while
/// iterating the directory is kept.
///
/// Returns `Ok(None)` when no candidate qualifies.
///
/// # Errors
///
/// Returns the underlying [`io::Error`] if `base_dir` cannot be read or if retrieving one of
/// its entries fails.
pub fn select_checkpoint(base_dir: &Path, criteria: &str) -> io::Result<Option<PathBuf>> {
    let mut selected: Option<PathBuf> = None;
    // `None` means "no candidate seen yet"; this avoids sentinel values that could collide
    // with (or wrongly exclude) real checkpoint numbers and losses.
    let mut highest_num: Option<i64> = None;
    let mut lowest_loss: Option<f64> = None;

    for entry in fs::read_dir(base_dir)? {
        let path = entry?.path();

        // Check the name first: it is a pure string operation, whereas `is_dir` hits the
        // filesystem.
        let checkpoint_num = match checkpoint_number(&path) {
            Some(num) => num,
            None => continue,
        };
        if !path.is_dir() {
            continue;
        }

        match criteria {
            "latest" => {
                if highest_num.map_or(true, |num| checkpoint_num > num) {
                    highest_num = Some(checkpoint_num);
                    selected = Some(path);
                }
            }
            "best" => {
                let loss = match read_eval_loss(&path) {
                    Some(loss) => loss,
                    None => continue,
                };
                if lowest_loss.map_or(true, |lowest| loss < lowest) {
                    lowest_loss = Some(loss);
                    selected = Some(path);
                }
            }
            // Unknown criteria select nothing; the directory is still walked so that I/O
            // errors are reported the same way for every criteria value.
            _ => {}
        }
    }

    Ok(selected)
}

/// Extracts `<N>` from a path whose final component is `checkpoint-<N>`.
///
/// Returns `None` if the name is not valid UTF-8, lacks the prefix, or `<N>` is not an `i64`.
fn checkpoint_number(path: &Path) -> Option<i64> {
    path.file_name()?
        .to_str()?
        .strip_prefix("checkpoint-")?
        .parse()
        .ok()
}

/// Reads `metrics.eval_loss` from `trainer_state.json` inside `checkpoint_dir`.
///
/// Returns `None` if the file cannot be read, is not valid JSON, or does not contain a numeric
/// value at that location.
fn read_eval_loss(checkpoint_dir: &Path) -> Option<f64> {
    let contents = fs::read_to_string(checkpoint_dir.join("trainer_state.json")).ok()?;
    let state: serde_json::Value = serde_json::from_str(&contents).ok()?;
    state
        .get("metrics")?
        .get("eval_loss")?
        .as_f64()
        // A NaN compares false against everything, so it could never be displaced once chosen.
        .filter(|loss| !loss.is_nan())
}