use serde::{Deserialize, Serialize};
/// Rolling metrics gathered while a training pipeline runs.
///
/// `quality_samples` is kept as a bounded window (the `impl` caps it at
/// 10 000 entries, dropping the oldest first) so memory stays constant
/// over long runs.
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub struct TrainingMetrics {
    /// Pipeline name these metrics describe.
    pub name: String,
    /// Total number of training examples processed.
    pub total_examples: usize,
    /// Number of training sessions recorded.
    pub training_sessions: u64,
    /// Count of patterns learned so far.
    pub patterns_learned: usize,
    /// Recent per-example quality scores (bounded window; see
    /// `add_quality_sample`).
    pub quality_samples: Vec<f32>,
    /// Quality measured on a validation set, if one was run.
    pub validation_quality: Option<f32>,
    /// Timing / throughput statistics.
    pub performance: PerformanceMetrics,
}
impl TrainingMetrics {
    /// Maximum number of quality samples retained; once full, the oldest
    /// sample is evicted first so memory stays bounded on long runs.
    /// (Previously the literal `10000` was duplicated in two methods.)
    const MAX_QUALITY_SAMPLES: usize = 10_000;

    /// Creates empty metrics for the pipeline `name`.
    pub fn new(name: &str) -> Self {
        Self {
            name: name.to_string(),
            ..Default::default()
        }
    }

    /// Records one quality sample, evicting the oldest once the bounded
    /// window is full.
    pub fn add_quality_sample(&mut self, quality: f32) {
        self.quality_samples.push(quality);
        if self.quality_samples.len() > Self::MAX_QUALITY_SAMPLES {
            // At most one element over the cap here, so remove(0) is fine.
            self.quality_samples.remove(0);
        }
    }

    /// Mean of the recorded quality samples; 0.0 when there are none.
    pub fn avg_quality(&self) -> f32 {
        if self.quality_samples.is_empty() {
            0.0
        } else {
            self.quality_samples.iter().sum::<f32>() / self.quality_samples.len() as f32
        }
    }

    /// Nearest-rank percentile (`percentile` expected in 0..=100) of the
    /// quality samples; 0.0 when no samples exist.
    ///
    /// NaN samples compare as equal under `partial_cmp`, so their sorted
    /// position is unspecified, but the sort never panics.
    pub fn quality_percentile(&self, percentile: f32) -> f32 {
        if self.quality_samples.is_empty() {
            return 0.0;
        }
        let mut sorted = self.quality_samples.clone();
        sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
        let idx = ((percentile / 100.0) * (sorted.len() - 1) as f32) as usize;
        // Clamp defensively in case percentile > 100 was passed.
        sorted[idx.min(sorted.len() - 1)]
    }

    /// Summary statistics over the current quality window; all-zero
    /// defaults when no samples have been recorded.
    pub fn quality_stats(&self) -> QualityMetrics {
        if self.quality_samples.is_empty() {
            return QualityMetrics::default();
        }
        let avg = self.avg_quality();
        let min = self
            .quality_samples
            .iter()
            .copied()
            .fold(f32::MAX, f32::min);
        let max = self
            .quality_samples
            .iter()
            .copied()
            .fold(f32::MIN, f32::max);
        // Population variance (divide by n, not n - 1).
        let variance = self
            .quality_samples
            .iter()
            .map(|q| (q - avg).powi(2))
            .sum::<f32>()
            / self.quality_samples.len() as f32;
        QualityMetrics {
            avg,
            min,
            max,
            std_dev: variance.sqrt(),
            p25: self.quality_percentile(25.0),
            p50: self.quality_percentile(50.0),
            p75: self.quality_percentile(75.0),
            p95: self.quality_percentile(95.0),
            sample_count: self.quality_samples.len(),
        }
    }

    /// Clears all counters and samples but keeps the pipeline name.
    pub fn reset(&mut self) {
        self.total_examples = 0;
        self.training_sessions = 0;
        self.patterns_learned = 0;
        self.quality_samples.clear();
        self.validation_quality = None;
        self.performance = PerformanceMetrics::default();
    }

    /// Folds `other` into `self`: example/session counters are summed
    /// and quality samples appended, trimming the oldest entries past
    /// the window cap. `validation_quality` and `performance` are left
    /// untouched.
    ///
    /// NOTE(review): `patterns_learned` is overwritten rather than
    /// accumulated, unlike the other counters — presumably it is a
    /// snapshot of the current pattern store rather than a delta;
    /// confirm with callers before changing.
    pub fn merge(&mut self, other: &TrainingMetrics) {
        self.total_examples += other.total_examples;
        self.training_sessions += other.training_sessions;
        self.patterns_learned = other.patterns_learned;
        self.quality_samples.extend(&other.quality_samples);
        if self.quality_samples.len() > Self::MAX_QUALITY_SAMPLES {
            let excess = self.quality_samples.len() - Self::MAX_QUALITY_SAMPLES;
            self.quality_samples.drain(0..excess);
        }
    }
}
/// Summary statistics computed over a window of quality samples.
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub struct QualityMetrics {
    /// Arithmetic mean of the samples.
    pub avg: f32,
    /// Smallest sample.
    pub min: f32,
    /// Largest sample.
    pub max: f32,
    /// Population standard deviation (n divisor, not n - 1).
    pub std_dev: f32,
    /// 25th percentile (nearest rank).
    pub p25: f32,
    /// Median (50th percentile).
    pub p50: f32,
    /// 75th percentile.
    pub p75: f32,
    /// 95th percentile.
    pub p95: f32,
    /// Number of samples the statistics were computed from.
    pub sample_count: usize,
}
impl std::fmt::Display for QualityMetrics {
    /// Renders a compact one-line summary, e.g.
    /// `avg=0.7000, std=0.1633, min=0.5000, max=0.9000, p50=0.7000, p95=0.9000 (n=3)`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "avg={:.4}, std={:.4}, ", self.avg, self.std_dev)?;
        write!(f, "min={:.4}, max={:.4}, ", self.min, self.max)?;
        write!(
            f,
            "p50={:.4}, p95={:.4} (n={})",
            self.p50, self.p95, self.sample_count
        )
    }
}
/// Timing and throughput measurements for a training run.
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub struct PerformanceMetrics {
    /// Total wall-clock training time, in seconds.
    pub total_training_secs: f64,
    /// Mean time per batch, in milliseconds.
    pub avg_batch_time_ms: f64,
    /// Mean time per example, in microseconds.
    pub avg_example_time_us: f64,
    /// Peak memory usage, in megabytes.
    pub peak_memory_mb: usize,
    /// Throughput, in examples per second.
    pub examples_per_sec: f64,
    /// Time spent in pattern extraction, in milliseconds.
    pub pattern_extraction_ms: f64,
}
impl PerformanceMetrics {
    /// Derives `examples_per_sec` and `avg_example_time_us` from a run
    /// of `examples` items that took `duration_secs` seconds. Both
    /// fields are left untouched unless the duration is strictly
    /// positive (this also skips NaN durations).
    pub fn calculate_throughput(&mut self, examples: usize, duration_secs: f64) {
        // Negated comparison (rather than `<= 0.0`) so a NaN duration
        // is also rejected, exactly as the positive check would.
        if !(duration_secs > 0.0) {
            return;
        }
        let n = examples as f64;
        self.examples_per_sec = n / duration_secs;
        self.avg_example_time_us = duration_secs * 1_000_000.0 / n;
    }
}
/// Per-epoch summary recorded during training.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct EpochStats {
    /// Zero-based epoch index (displayed one-based by `Display`).
    pub epoch: usize,
    /// Examples processed during this epoch.
    pub examples_processed: usize,
    /// Mean quality over this epoch.
    pub avg_quality: f32,
    /// Epoch wall-clock duration, in seconds.
    pub duration_secs: f64,
}
impl std::fmt::Display for EpochStats {
    /// One-line human-readable summary; the zero-based `epoch` field is
    /// shown one-based, e.g. `Epoch 1: 333 examples, avg_quality=0.7500, 3.00s`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let shown_epoch = self.epoch + 1;
        write!(f, "Epoch {}: {} examples, ", shown_epoch, self.examples_processed)?;
        write!(f, "avg_quality={:.4}, {:.2}s", self.avg_quality, self.duration_secs)
    }
}
/// Final outcome of a completed training run.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct TrainingResult {
    /// Name of the pipeline that produced this result.
    pub pipeline_name: String,
    /// Number of epochs that ran to completion.
    pub epochs_completed: usize,
    /// Total examples processed across all epochs.
    pub total_examples: usize,
    /// Patterns learned by the end of the run.
    pub patterns_learned: usize,
    /// Average quality at the end of training.
    pub final_avg_quality: f32,
    /// Total wall-clock duration, in seconds.
    pub total_duration_secs: f64,
    /// Per-epoch summaries, in epoch order.
    pub epoch_stats: Vec<EpochStats>,
    /// Quality on a held-out validation set, if measured.
    pub validation_quality: Option<f32>,
}
impl TrainingResult {
    /// Overall throughput in examples per second; 0.0 when the recorded
    /// duration is not positive, avoiding a division by zero.
    pub fn examples_per_sec(&self) -> f64 {
        if self.total_duration_secs > 0.0 {
            self.total_examples as f64 / self.total_duration_secs
        } else {
            0.0
        }
    }

    /// Mean wall-clock duration per epoch; 0.0 when no epochs completed.
    pub fn avg_epoch_duration(&self) -> f64 {
        if self.epochs_completed > 0 {
            self.total_duration_secs / self.epochs_completed as f64
        } else {
            0.0
        }
    }

    /// Whether average quality strictly increased between the first and
    /// last epoch. `false` with fewer than two epochs of data.
    pub fn quality_improved(&self) -> bool {
        // Delegate so the first-vs-last extraction lives in one place
        // (previously this logic was duplicated in both methods).
        self.quality_improvement() > 0.0
    }

    /// Change in average quality from the first to the last epoch;
    /// 0.0 with fewer than two epochs of data.
    pub fn quality_improvement(&self) -> f32 {
        if self.epoch_stats.len() < 2 {
            return 0.0;
        }
        // len >= 2 guarantees both first() and last() are Some, so the
        // match arm below is always taken; no unwrap needed.
        match (self.epoch_stats.first(), self.epoch_stats.last()) {
            (Some(first), Some(last)) => last.avg_quality - first.avg_quality,
            _ => 0.0,
        }
    }
}
impl std::fmt::Display for TrainingResult {
    /// Single-line summary of the run, including the derived throughput.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(
            f,
            "TrainingResult(pipeline={}, epochs={}, examples={}, patterns={}, ",
            self.pipeline_name, self.epochs_completed, self.total_examples, self.patterns_learned
        )?;
        write!(
            f,
            "final_quality={:.4}, duration={:.2}s, throughput={:.1}/s)",
            self.final_avg_quality,
            self.total_duration_secs,
            self.examples_per_sec()
        )
    }
}
/// Differences between two training results (`comparison` relative to
/// `baseline`); positive deltas mean the comparison run did better.
#[derive(Clone, Debug, Serialize, Deserialize)]
#[allow(dead_code)]
pub struct TrainingComparison {
    /// Name of the baseline pipeline.
    pub baseline_name: String,
    /// Name of the pipeline compared against the baseline.
    pub comparison_name: String,
    /// Absolute change in final average quality.
    pub quality_diff: f32,
    /// Quality change as a percentage of the baseline quality
    /// (0.0 when the baseline quality is not positive).
    pub quality_improvement_pct: f32,
    /// Change in throughput, in examples per second.
    pub throughput_diff: f64,
    /// Change in total duration, in seconds (negative = faster).
    pub duration_diff: f64,
}
#[allow(dead_code)]
impl TrainingComparison {
    /// Builds a comparison of `comparison` against `baseline`.
    ///
    /// The percentage improvement is reported as 0.0 when the baseline
    /// quality is not positive, since a ratio would be meaningless.
    pub fn compare(baseline: &TrainingResult, comparison: &TrainingResult) -> Self {
        let quality_diff = comparison.final_avg_quality - baseline.final_avg_quality;
        let pct = match baseline.final_avg_quality {
            base if base > 0.0 => quality_diff / base * 100.0,
            _ => 0.0,
        };
        Self {
            baseline_name: baseline.pipeline_name.clone(),
            comparison_name: comparison.pipeline_name.clone(),
            quality_diff,
            quality_improvement_pct: pct,
            throughput_diff: comparison.examples_per_sec() - baseline.examples_per_sec(),
            duration_diff: comparison.total_duration_secs - baseline.total_duration_secs,
        }
    }
}
impl std::fmt::Display for TrainingComparison {
    /// One-line summary; non-negative deltas get an explicit `+` prefix
    /// (negative values already carry their own `-`).
    ///
    /// The quality sign is reused for the percentage field, mirroring
    /// the fact that both share the sign of `quality_diff`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let sign_of = |non_negative: bool| if non_negative { "+" } else { "" };
        let q_sign = sign_of(self.quality_diff >= 0.0);
        let t_sign = sign_of(self.throughput_diff >= 0.0);
        write!(
            f,
            "Comparison {} vs {}: quality {}{:.4} ({}{:.1}%), throughput {}{:.1}/s",
            self.comparison_name,
            self.baseline_name,
            q_sign,
            self.quality_diff,
            q_sign,
            self.quality_improvement_pct,
            t_sign,
            self.throughput_diff
        )
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a two-epoch `TrainingResult` with empty epoch stats, used
    /// by the comparison test.
    fn result_with(name: &str, patterns: usize, quality: f32, secs: f64) -> TrainingResult {
        TrainingResult {
            pipeline_name: name.into(),
            epochs_completed: 2,
            total_examples: 500,
            patterns_learned: patterns,
            final_avg_quality: quality,
            total_duration_secs: secs,
            epoch_stats: vec![],
            validation_quality: None,
        }
    }

    #[test]
    fn test_metrics_creation() {
        let fresh = TrainingMetrics::new("test");
        assert_eq!(fresh.name, "test");
        assert_eq!(fresh.total_examples, 0);
    }

    #[test]
    fn test_quality_samples() {
        let mut m = TrainingMetrics::new("test");
        (0..10).for_each(|i| m.add_quality_sample(i as f32 / 10.0));
        assert_eq!(m.quality_samples.len(), 10);
        assert!((m.avg_quality() - 0.45).abs() < 0.01);
    }

    #[test]
    fn test_quality_percentiles() {
        let mut m = TrainingMetrics::new("test");
        (0..100).for_each(|i| m.add_quality_sample(i as f32 / 100.0));
        assert!((m.quality_percentile(50.0) - 0.5).abs() < 0.02);
        assert!((m.quality_percentile(95.0) - 0.95).abs() < 0.02);
    }

    #[test]
    fn test_quality_stats() {
        let mut m = TrainingMetrics::new("test");
        for &q in [0.5_f32, 0.7, 0.9].iter() {
            m.add_quality_sample(q);
        }
        let stats = m.quality_stats();
        assert!((stats.avg - 0.7).abs() < 0.01);
        assert!((stats.min - 0.5).abs() < 0.01);
        assert!((stats.max - 0.9).abs() < 0.01);
    }

    #[test]
    fn test_training_result() {
        let epochs: Vec<EpochStats> = [
            (0_usize, 333_usize, 0.75_f32, 3.0_f64),
            (1, 333, 0.80, 3.5),
            (2, 334, 0.85, 3.5),
        ]
        .iter()
        .map(|&(epoch, examples_processed, avg_quality, duration_secs)| EpochStats {
            epoch,
            examples_processed,
            avg_quality,
            duration_secs,
        })
        .collect();
        let result = TrainingResult {
            pipeline_name: "test".into(),
            epochs_completed: 3,
            total_examples: 1000,
            patterns_learned: 50,
            final_avg_quality: 0.85,
            total_duration_secs: 10.0,
            epoch_stats: epochs,
            validation_quality: Some(0.82),
        };
        assert_eq!(result.examples_per_sec(), 100.0);
        assert!(result.quality_improved());
        assert!((result.quality_improvement() - 0.10).abs() < 0.01);
    }

    #[test]
    fn test_training_comparison() {
        let baseline = result_with("baseline", 25, 0.70, 5.0);
        let improved = result_with("improved", 30, 0.85, 4.0);
        let comparison = TrainingComparison::compare(&baseline, &improved);
        assert!((comparison.quality_diff - 0.15).abs() < 0.01);
        assert!(comparison.quality_improvement_pct > 20.0);
        assert!(comparison.throughput_diff > 0.0);
    }
}