1use core::cmp::Ordering;
2use std::{
3 collections::{HashMap, hash_map::Entry},
4 fmt::Display,
5 path::{Path, PathBuf},
6};
7
8use crate::{
9 logger::FileMetricLogger,
10 metric::store::{Aggregate, EventStore, LogEventStore, Split},
11};
12
/// A single recorded value of a metric at a given step.
#[derive(Debug)]
pub struct MetricEntry {
    /// Step (epoch, 1-based) at which the value was recorded.
    pub step: usize,
    /// The metric value recorded at this step.
    pub value: f64,
}
21
/// The recorded history of one metric within a single split.
#[derive(Debug)]
pub struct MetricSummary {
    /// Name of the metric (as looked up in the event store).
    pub name: String,
    /// One entry per epoch for which a value was found.
    pub entries: Vec<MetricEntry>,
}
30
31impl MetricSummary {
32 fn collect<E: EventStore>(
33 event_store: &mut E,
34 metric: &str,
35 split: &Split,
36 num_epochs: usize,
37 ) -> Option<Self> {
38 let entries = (1..=num_epochs)
39 .filter_map(|epoch| {
40 event_store
41 .find_metric(metric, epoch, Aggregate::Mean, split)
42 .map(|value| MetricEntry { step: epoch, value })
43 })
44 .collect::<Vec<_>>();
45
46 if entries.is_empty() {
47 None
48 } else {
49 Some(Self {
50 name: metric.to_string(),
51 entries,
52 })
53 }
54 }
55}
56
/// Metric summaries grouped by dataset split.
pub struct SummaryMetrics {
    /// Summaries collected from the training split.
    pub train: Vec<MetricSummary>,
    /// Summaries collected from the validation split.
    pub valid: Vec<MetricSummary>,
    /// Test summaries keyed by tag; the empty-string key holds the
    /// untagged test split.
    pub test: HashMap<String, Vec<MetricSummary>>,
}
69
/// Aggregated summary of a learner's logged metrics across splits.
pub struct LearnerSummary {
    /// Number of epochs reported by the metric logger.
    pub epochs: usize,
    /// Collected metric summaries per split.
    pub metrics: SummaryMetrics,
    /// Optional model description shown when the summary is displayed.
    pub(crate) model: Option<String>,
}
79
impl LearnerSummary {
    /// Create a summary by scanning the artifact `directory` for logged metrics.
    ///
    /// # Arguments
    ///
    /// * `directory` - Root directory of the training artifacts; must exist.
    /// * `metrics` - Names of the metrics to collect for each split.
    ///
    /// # Errors
    ///
    /// Returns an error string when the directory does not exist, or when no
    /// training, validation or test artifacts are found inside it.
    pub fn new<S: AsRef<str>>(directory: impl AsRef<Path>, metrics: &[S]) -> Result<Self, String> {
        let directory = directory.as_ref();
        if !directory.exists() {
            return Err(format!(
                "Artifact directory does not exist at: {}",
                directory.display()
            ));
        }

        let mut event_store = LogEventStore::default();
        let train_split = Split::Train;
        let valid_split = Split::Valid;

        let logger = FileMetricLogger::new(directory);
        // Root of the test artifacts, or `None` when no test split was logged.
        let test_split_root = logger.split_dir(&Split::Test(None));
        if !logger.split_exists(&train_split)
            && !logger.split_exists(&valid_split)
            && test_split_root.is_none()
        {
            return Err(format!(
                "No training, validation or test artifacts found at: {}",
                directory.display()
            ));
        }

        // Read the epoch count now: the logger is moved into the store below.
        let epochs = logger.epochs();

        event_store.register_logger(logger);

        // Metrics with no recorded values are dropped silently
        // (`MetricSummary::collect` returns `None` for them).
        let train_summary = metrics
            .iter()
            .filter_map(|metric| {
                MetricSummary::collect(&mut event_store, metric.as_ref(), &train_split, epochs)
            })
            .collect::<Vec<_>>();

        let valid_summary = metrics
            .iter()
            .filter_map(|metric| {
                MetricSummary::collect(&mut event_store, metric.as_ref(), &valid_split, epochs)
            })
            .collect::<Vec<_>>();

        // Test metrics may be tagged (one sub-directory per tag) or untagged;
        // the helper inspects the layout under the test root.
        let test_summary = match test_split_root {
            Some(root) => collect_test_split_metrics(root, metrics, &mut event_store, epochs),
            None => Default::default(),
        };

        Ok(Self {
            epochs,
            metrics: SummaryMetrics {
                train: train_summary,
                valid: valid_summary,
                test: test_summary,
            },
            model: None,
        })
    }

    /// Attach a model description to be displayed with the summary.
    pub(crate) fn with_model(mut self, name: String) -> Self {
        self.model = Some(name);
        self
    }

    /// Merge `other` into this summary, concatenating the entries of metrics
    /// that appear in both. The model description is kept only when both
    /// summaries agree on it.
    pub(crate) fn merge(mut self, other: LearnerSummary) -> Self {
        // Merge two metric lists by name; entries of duplicated metrics are
        // appended to the existing summary.
        // NOTE(review): `HashMap::into_values` makes the resulting metric
        // order unspecified.
        fn merge_metrics(
            base: Vec<MetricSummary>,
            incoming: Vec<MetricSummary>,
        ) -> Vec<MetricSummary> {
            let mut map: HashMap<String, MetricSummary> =
                base.into_iter().map(|m| (m.name.clone(), m)).collect();

            for metric in incoming {
                match map.entry(metric.name.clone()) {
                    Entry::Occupied(mut entry) => {
                        entry.get_mut().entries.extend(metric.entries);
                    }
                    Entry::Vacant(entry) => {
                        entry.insert(metric);
                    }
                }
            }
            map.into_values().collect()
        }

        self.metrics.train = merge_metrics(self.metrics.train, other.metrics.train);
        self.metrics.valid = merge_metrics(self.metrics.valid, other.metrics.valid);

        // Test summaries are merged per tag.
        for (tag, metrics) in other.metrics.test {
            match self.metrics.test.entry(tag) {
                Entry::Occupied(mut entry) => {
                    // Take the current list out so it can be consumed by value.
                    let current = std::mem::take(entry.get_mut());
                    let merged = merge_metrics(current, metrics);
                    *entry.get_mut() = merged;
                }
                Entry::Vacant(entry) => {
                    entry.insert(metrics);
                }
            }
        }

        // Conflicting model descriptions cancel each other out.
        if self.model != other.model {
            self.model = None;
        }

        self
    }
}
197
198fn collect_test_split_metrics<P: AsRef<Path>, S: AsRef<str>>(
199 root: P,
200 metrics: &[S],
201 event_store: &mut LogEventStore,
202 epochs: usize,
203) -> HashMap<String, Vec<MetricSummary>> {
204 let dirs = match std::fs::read_dir(root) {
206 Ok(entries) => entries
207 .filter_map(|entry| {
208 let entry = entry.ok()?;
209 let file_type = entry.file_type().ok()?;
210 if file_type.is_dir() {
211 Some(entry.file_name().to_string_lossy().to_string())
212 } else {
213 None
214 }
215 })
216 .collect::<Vec<_>>(),
217 Err(_) => Vec::new(),
218 };
219
220 let mut map = HashMap::new();
221
222 if dirs.is_empty() {
223 return map;
224 }
225
226 let all_epochs = dirs.iter().all(FileMetricLogger::is_epoch_dir);
228
229 if all_epochs {
230 let split = Split::Test(None);
232
233 let summaries = metrics
234 .iter()
235 .filter_map(|metric| {
236 MetricSummary::collect(event_store, metric.as_ref(), &split, epochs)
237 })
238 .collect::<Vec<_>>();
239
240 map.insert("".to_string(), summaries);
242 } else {
243 for tag in dirs {
245 let split = Split::Test(Some(tag.clone().into()));
246
247 let summaries = metrics
248 .iter()
249 .filter_map(|metric| {
250 MetricSummary::collect(event_store, metric.as_ref(), &split, epochs)
251 })
252 .collect::<Vec<_>>();
253
254 map.insert(tag, summaries);
255 }
256 }
257
258 map
259}
260
impl Display for LearnerSummary {
    /// Render the summary as a fixed-width table listing, for each metric in
    /// each split, its minimum and maximum value and the epochs at which they
    /// occurred.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Column widths: start at the header labels ("Split" is 5 characters,
        // "Metric") and widen to the longest split/metric name encountered.
        let mut max_split_len = 5;
        let mut max_metric_len = "Metric".len();
        for metric in self.metrics.train.iter() {
            max_metric_len = max_metric_len.max(metric.name.len());
        }
        for metric in self.metrics.valid.iter() {
            max_metric_len = max_metric_len.max(metric.name.len());
        }
        for (tag, metrics) in self.metrics.test.iter() {
            // Untagged test summaries are stored under the empty-string key.
            let split_name = if tag.is_empty() {
                "Test".to_string()
            } else {
                format!("Test ({tag})")
            };

            max_split_len = max_split_len.max(split_name.len());

            for metric in metrics {
                max_metric_len = max_metric_len.max(metric.name.len());
            }
        }

        // Banner: 24 '=' on each side of the title.
        writeln!(
            f,
            "{:=>width_symbol$} Learner Summary {:=>width_symbol$}",
            "",
            "",
            width_symbol = 24,
        )?;

        if let Some(model) = &self.model {
            writeln!(f, "Model:\n{model}")?;
        }
        writeln!(f, "Total Epochs: {epochs}\n\n", epochs = self.epochs)?;

        // Table header followed by the dashed separator row.
        writeln!(
            f,
            "| {:<width_split$} | {:<width_metric$} | Min. | Epoch | Max. | Epoch |\n|{:->width_split$}--|{:->width_metric$}--|----------|----------|----------|----------|",
            "Split",
            "Metric",
            "",
            "",
            width_split = max_split_len,
            width_metric = max_metric_len,
        )?;

        // Total order over f64 that ranks NaN greater than every number, so
        // the min_by/max_by calls below never hit an unordered pair.
        fn cmp_f64(a: &f64, b: &f64) -> Ordering {
            match (a.is_nan(), b.is_nan()) {
                (true, true) => Ordering::Equal,
                (true, false) => Ordering::Greater,
                (false, true) => Ordering::Less,
                _ => a.partial_cmp(b).unwrap(),
            }
        }

        // Format a value left-aligned in a 9-character column; small values
        // use scientific notation.
        // NOTE(review): the comparison is on the signed value, so every
        // negative value also takes the scientific branch — confirm whether
        // `val.abs() < 1e-2` was intended.
        fn fmt_val(val: f64) -> String {
            if val < 1e-2 {
                format!("{val:<9.3e}")
            } else {
                format!("{val:<9.3}")
            }
        }

        // Write one table row per metric of a split; metrics without entries
        // are skipped.
        let mut write_metrics_summary =
            |metrics: &[MetricSummary], split: String| -> std::fmt::Result {
                for metric in metrics.iter() {
                    if metric.entries.is_empty() {
                        continue;
                    }

                    let metric_min = metric
                        .entries
                        .iter()
                        .min_by(|a, b| cmp_f64(&a.value, &b.value))
                        .unwrap();
                    let metric_max = metric
                        .entries
                        .iter()
                        .max_by(|a, b| cmp_f64(&a.value, &b.value))
                        .unwrap();

                    writeln!(
                        f,
                        "| {:<width_split$} | {:<width_metric$} | {}| {:<9?}| {}| {:<9?}|",
                        split,
                        metric.name,
                        fmt_val(metric_min.value),
                        metric_min.step,
                        fmt_val(metric_max.value),
                        metric_max.step,
                        width_split = max_split_len,
                        width_metric = max_metric_len,
                    )?;
                }

                Ok(())
            };

        write_metrics_summary(&self.metrics.train, format!("{:?}", Split::Train))?;
        write_metrics_summary(&self.metrics.valid, format!("{:?}", Split::Valid))?;

        for (tag, metrics) in &self.metrics.test {
            let split_name = if tag.is_empty() {
                "Test".to_string()
            } else {
                format!("Test ({tag})")
            };

            write_metrics_summary(metrics, split_name)?;
        }

        Ok(())
    }
}
383
/// Configuration from which a [`LearnerSummary`] can be built.
#[derive(Clone)]
pub struct LearnerSummaryConfig {
    /// Artifact directory to read metric logs from.
    pub(crate) directory: PathBuf,
    /// Names of the metrics to summarize.
    pub(crate) metrics: Vec<String>,
}
392
393impl LearnerSummaryConfig {
394 pub fn init(&self) -> Result<LearnerSummary, String> {
396 LearnerSummary::new(&self.directory, &self.metrics[..])
397 }
398}
399
#[cfg(test)]
mod tests {
    use super::*;

    // A missing artifact directory must be reported as an error.
    #[test]
    #[should_panic = "Summary artifacts should exist"]
    fn test_artifact_dir_should_exist() {
        let dir = "/tmp/learner-summary-not-found";
        let _summary = LearnerSummary::new(dir, &["Loss"]).expect("Summary artifacts should exist");
    }

    // An existing but empty directory (no split artifacts) is also an error.
    #[test]
    #[should_panic = "Summary artifacts should exist"]
    fn test_train_valid_artifacts_should_exist() {
        let dir = "/tmp/test-learner-summary-empty";
        std::fs::create_dir_all(dir).ok();
        let _summary = LearnerSummary::new(dir, &["Loss"]).expect("Summary artifacts should exist");
    }

    // Epoch directories without any metric log files yield empty summaries
    // (the requested "Loss" metric is simply absent).
    #[test]
    fn test_summary_should_be_empty() {
        let dir = Path::new("/tmp/test-learner-summary-empty-metrics");
        std::fs::create_dir_all(dir).unwrap();
        std::fs::create_dir_all(dir.join("train/epoch-1")).unwrap();
        std::fs::create_dir_all(dir.join("valid/epoch-1")).unwrap();
        let summary = LearnerSummary::new(dir.to_str().unwrap(), &["Loss"])
            .expect("Summary artifacts should exist");

        assert_eq!(summary.epochs, 1);

        assert_eq!(summary.metrics.train.len(), 0);
        assert_eq!(summary.metrics.valid.len(), 0);

        std::fs::remove_dir_all(dir).unwrap();
    }

    // Logged metric values are collected and aggregated per epoch.
    #[test]
    fn test_summary_should_be_collected() {
        let dir = Path::new("/tmp/test-learner-summary");
        let train_dir = dir.join("train/epoch-1");
        let valid_dir = dir.join("valid/epoch-1");
        std::fs::create_dir_all(dir).unwrap();
        std::fs::create_dir_all(&train_dir).unwrap();
        std::fs::create_dir_all(&valid_dir).unwrap();

        std::fs::write(train_dir.join("Loss.log"), "1.0\n2.0").expect("Unable to write file");
        std::fs::write(valid_dir.join("Loss.log"), "1.0").expect("Unable to write file");

        let summary = LearnerSummary::new(dir.to_str().unwrap(), &["Loss"])
            .expect("Summary artifacts should exist");

        assert_eq!(summary.epochs, 1);

        assert_eq!(summary.metrics.train.len(), 1);
        assert_eq!(summary.metrics.valid.len(), 1);

        let train_metric = &summary.metrics.train[0];
        assert_eq!(train_metric.name, "Loss");
        assert_eq!(train_metric.entries.len(), 1);
        let entry = &train_metric.entries[0];
        assert_eq!(entry.step, 1);
        // Mean of the two logged values 1.0 and 2.0.
        assert_eq!(entry.value, 1.5);
        let valid_metric = &summary.metrics.valid[0];
        assert_eq!(valid_metric.name, "Loss");
        assert_eq!(valid_metric.entries.len(), 1);
        let entry = &valid_metric.entries[0];
        assert_eq!(entry.step, 1);
        assert_eq!(entry.value, 1.0);

        std::fs::remove_dir_all(dir).unwrap();
    }
}
475}