use std::collections::BTreeMap;
use std::fs::{File, OpenOptions};
use std::io::{BufRead, BufReader, Write};
use std::path::{Path, PathBuf};
use serde::{Deserialize, Serialize};
/// One row of the metrics log.
///
/// `MetricsLog` serializes each entry with `serde_json` as a single line of
/// the backing file and parses one entry per non-blank line when reopening it.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct MetricsEntry {
/// Step index of this row (presumably the training step; nothing in this
/// file enforces ordering or uniqueness).
pub step: u32,
/// Wall-clock time in milliseconds. `MetricsLog::summary` reports the value
/// of the last entry as the total, so it is treated as cumulative there.
pub wall_time_ms: f64,
/// Loss value; aggregated into min / mean / max by `MetricsLog::summary`.
pub loss: f32,
/// Gradient norm for this row (presumably produced by the `grad_norm`
/// function in this file; not enforced here).
pub grad_norm: f32,
/// Throughput in samples per second, going by the field name; the unit is
/// not checked anywhere in this file.
pub throughput_samples_per_sec: f32,
/// Additional named values. `#[serde(default)]` makes this an empty map
/// when the key is absent from the parsed input.
#[serde(default)]
pub extra: BTreeMap<String, f32>,
}
/// Aggregate statistics over every entry held by a `MetricsLog`.
///
/// All fields are zero when the log has no entries.
#[derive(Debug, Clone, PartialEq)]
pub struct MetricsSummary {
/// Number of entries in the log (a count of rows, not the largest `step`).
pub total_steps: u32,
/// `wall_time_ms` of the last entry in the log.
pub total_wall_time_ms: f64,
/// Smallest `loss` seen across the entries.
pub loss_min: f32,
/// Arithmetic mean of `loss`, accumulated in `f64` and narrowed to `f32`.
pub loss_mean: f32,
/// Largest `loss` seen across the entries.
pub loss_max: f32,
}
/// A metrics log backed by a file holding one JSON-encoded `MetricsEntry`
/// per line, mirrored in memory.
pub struct MetricsLog {
/// Every entry loaded from the file at open time plus every entry recorded
/// since, in file order.
entries: Vec<MetricsEntry>,
/// Location of the backing file that new entries are appended to.
path: PathBuf,
}
impl MetricsLog {
    /// Opens the metrics log stored at `path`, loading any entries already in it.
    ///
    /// Missing parent directories are created. A file that does not exist yet
    /// yields an empty log; the file itself is only created by the first call
    /// to [`MetricsLog::record`]. Blank lines in an existing file are skipped.
    ///
    /// # Errors
    ///
    /// Returns the underlying I/O error if the parent directory cannot be
    /// created or the file cannot be opened or read for any reason other than
    /// not existing. Returns an `InvalidData` error naming the 1-based line
    /// number if a non-blank line is not a valid JSON `MetricsEntry`.
    pub fn open(path: &Path) -> std::io::Result<Self> {
        if let Some(parent) = path.parent() {
            // A bare file name has an empty parent; there is nothing to create.
            if !parent.as_os_str().is_empty() {
                std::fs::create_dir_all(parent)?;
            }
        }

        let mut entries: Vec<MetricsEntry> = Vec::new();
        // Open directly rather than probing with `exists()` first: the probe is
        // racy and reports `false` on metadata errors, which would silently
        // turn an unreadable log into an empty one.
        match File::open(path) {
            Ok(file) => {
                for (lineno, line) in BufReader::new(file).lines().enumerate() {
                    let line = line?;
                    if line.trim().is_empty() {
                        continue;
                    }
                    let entry: MetricsEntry = serde_json::from_str(&line).map_err(|e| {
                        std::io::Error::new(
                            std::io::ErrorKind::InvalidData,
                            format!("metrics_log parse error at line {}: {e}", lineno + 1),
                        )
                    })?;
                    entries.push(entry);
                }
            }
            // No file yet: start with an empty log.
            Err(e) if e.kind() == std::io::ErrorKind::NotFound => {}
            Err(e) => return Err(e),
        }

        Ok(Self {
            entries,
            path: path.to_path_buf(),
        })
    }

    /// Appends `entry` to the backing file and to the in-memory list.
    ///
    /// The file is opened in append mode (and created if missing) on every
    /// call. The in-memory list is only updated after the write succeeded, so
    /// a failed write leaves it unchanged.
    ///
    /// NOTE(review): `serde_json` is understood to encode non-finite floats
    /// (NaN, infinities) as `null`, which `open` would then fail to parse back
    /// into `f32`. Confirm and decide whether such entries should be rejected
    /// here.
    ///
    /// # Errors
    ///
    /// Returns `InvalidData` if the entry cannot be serialized, or the
    /// underlying I/O error if the file cannot be opened or written.
    pub fn record(&mut self, entry: MetricsEntry) -> std::io::Result<()> {
        let mut line: String = serde_json::to_string(&entry)
            .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
        line.push('\n');

        let mut f = OpenOptions::new()
            .create(true)
            .append(true)
            .open(&self.path)?;
        // Hand the record and its newline to the file in one buffer. Formatting
        // straight into an unbuffered file can split them into separate writes,
        // which may interleave with other appenders or leave a line without its
        // terminator if interrupted.
        f.write_all(line.as_bytes())?;

        self.entries.push(entry);
        Ok(())
    }

    /// Computes aggregate statistics over all entries currently in the log.
    ///
    /// Returns an all-zero summary for an empty log. `total_wall_time_ms` is
    /// the `wall_time_ms` of the last entry. The mean is accumulated in `f64`.
    /// A NaN loss never wins a `<` / `>` comparison, so it is ignored by
    /// `loss_min` / `loss_max`, but it does propagate into `loss_mean`.
    pub fn summary(&self) -> MetricsSummary {
        let last = match self.entries.last() {
            Some(last) => last,
            None => {
                return MetricsSummary {
                    total_steps: 0,
                    total_wall_time_ms: 0.0,
                    loss_min: 0.0,
                    loss_mean: 0.0,
                    loss_max: 0.0,
                }
            }
        };

        let mut loss_min = f32::INFINITY;
        let mut loss_max = f32::NEG_INFINITY;
        let mut loss_sum: f64 = 0.0;
        for e in &self.entries {
            if e.loss < loss_min {
                loss_min = e.loss;
            }
            if e.loss > loss_max {
                loss_max = e.loss;
            }
            loss_sum += e.loss as f64;
        }

        let count = self.entries.len();
        MetricsSummary {
            // Saturate instead of silently wrapping if the count exceeds `u32`.
            total_steps: count.min(u32::MAX as usize) as u32,
            total_wall_time_ms: last.wall_time_ms,
            loss_min,
            loss_mean: (loss_sum / count as f64) as f32,
            loss_max,
        }
    }

    /// Iterates over the entries in the order they appear in the log.
    pub fn iter(&self) -> impl Iterator<Item = &MetricsEntry> {
        self.entries.iter()
    }

    /// Returns the number of entries in the log.
    pub fn len(&self) -> usize {
        self.entries.len()
    }

    /// Returns `true` when the log holds no entries.
    pub fn is_empty(&self) -> bool {
        self.entries.is_empty()
    }
}
/// Returns the Euclidean (L2) norm of `grads`.
///
/// Each component is widened to `f64` before squaring and the squares are
/// accumulated in `f64`, so intermediate squares do not overflow `f32`; only
/// the final square root is narrowed back to `f32`. An empty slice yields `0.0`.
pub fn grad_norm(grads: &[f32]) -> f32 {
    // Left-to-right fold seeded with +0.0, matching a plain running sum.
    let squared_total = grads.iter().fold(0.0_f64, |acc, &component| {
        let wide = f64::from(component);
        acc + wide * wide
    });
    squared_total.sqrt() as f32
}