1use core::cmp::Ordering;
2use std::{
3 fmt::Display,
4 path::{Path, PathBuf},
5};
6
7use crate::{
8 logger::FileMetricLogger,
9 metric::store::{Aggregate, EventStore, LogEventStore, Split},
10};
11
/// A single recorded value of a metric at a given point in training.
pub struct MetricEntry {
    /// The step at which the metric was recorded (here: the epoch number, 1-based).
    pub step: usize,
    /// The aggregated metric value at that step.
    pub value: f64,
}
19
/// The per-epoch history of one named metric for a single split.
pub struct MetricSummary {
    /// The name of the metric (as logged, e.g. "Loss").
    pub name: String,
    /// One entry per epoch for which a value was recorded.
    pub entries: Vec<MetricEntry>,
}
27
28impl MetricSummary {
29 fn new<E: EventStore>(
30 event_store: &mut E,
31 metric: &str,
32 split: Split,
33 num_epochs: usize,
34 ) -> Option<Self> {
35 let entries = (1..=num_epochs)
36 .filter_map(|epoch| {
37 event_store
38 .find_metric(metric, epoch, Aggregate::Mean, split)
39 .map(|value| MetricEntry { step: epoch, value })
40 })
41 .collect::<Vec<_>>();
42
43 if entries.is_empty() {
44 None
45 } else {
46 Some(Self {
47 name: metric.to_string(),
48 entries,
49 })
50 }
51 }
52}
53
/// Summarized metrics grouped by dataset split.
pub struct SummaryMetrics {
    /// Metric summaries collected from the training split.
    pub train: Vec<MetricSummary>,
    /// Metric summaries collected from the validation split.
    pub valid: Vec<MetricSummary>,
}
61
/// A summary of a training run, built from the artifacts on disk.
pub struct LearnerSummary {
    /// The number of epochs found in the training logs.
    pub epochs: usize,
    /// The collected metric summaries for the train and valid splits.
    pub metrics: SummaryMetrics,
    /// Optional model description, set via `with_model`; printed by `Display` when present.
    pub(crate) model: Option<String>,
}
71
72impl LearnerSummary {
73 pub fn new<S: AsRef<str>>(directory: impl AsRef<Path>, metrics: &[S]) -> Result<Self, String> {
80 let directory = directory.as_ref();
81 if !directory.exists() {
82 return Err(format!(
83 "Artifact directory does not exist at: {}",
84 directory.display()
85 ));
86 }
87 let train_dir = directory.join("train");
88 let valid_dir = directory.join("valid");
89 if !train_dir.exists() & !valid_dir.exists() {
90 return Err(format!(
91 "No training or validation artifacts found at: {}",
92 directory.display()
93 ));
94 }
95
96 let mut event_store = LogEventStore::default();
97
98 let train_logger = FileMetricLogger::new_train(train_dir.to_str().unwrap());
99 let valid_logger = FileMetricLogger::new_train(valid_dir.to_str().unwrap());
100
101 let epochs = train_logger.epochs();
103
104 event_store.register_logger_train(train_logger);
105 event_store.register_logger_valid(valid_logger);
106
107 let train_summary = metrics
108 .iter()
109 .filter_map(|metric| {
110 MetricSummary::new(&mut event_store, metric.as_ref(), Split::Train, epochs)
111 })
112 .collect::<Vec<_>>();
113
114 let valid_summary = metrics
115 .iter()
116 .filter_map(|metric| {
117 MetricSummary::new(&mut event_store, metric.as_ref(), Split::Valid, epochs)
118 })
119 .collect::<Vec<_>>();
120
121 Ok(Self {
122 epochs,
123 metrics: SummaryMetrics {
124 train: train_summary,
125 valid: valid_summary,
126 },
127 model: None,
128 })
129 }
130
131 pub(crate) fn with_model(mut self, name: String) -> Self {
132 self.model = Some(name);
133 self
134 }
135}
136
// Renders the summary as a human-readable report: a header, the optional
// model description, the epoch count, and a markdown-style table with the
// min/max value (and the epoch at which each occurred) per metric and split.
impl Display for LearnerSummary {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let split_train = "Train";
        let split_valid = "Valid";
        // Column widths are sized to the longest label so the table stays aligned.
        let max_split_len = "Split".len().max(split_train.len()).max(split_valid.len());
        let mut max_metric_len = "Metric".len();
        for metric in self.metrics.train.iter() {
            max_metric_len = max_metric_len.max(metric.name.len());
        }
        for metric in self.metrics.valid.iter() {
            max_metric_len = max_metric_len.max(metric.name.len());
        }

        // Header line: 24 '=' padding on each side of the title.
        writeln!(
            f,
            "{:=>width_symbol$} Learner Summary {:=>width_symbol$}",
            "",
            "",
            width_symbol = 24,
        )?;

        if let Some(model) = &self.model {
            writeln!(f, "Model:\n{model}")?;
        }
        writeln!(f, "Total Epochs: {epochs}\n\n", epochs = self.epochs)?;

        // Table header row plus the dashed separator row, padded to the
        // computed column widths.
        writeln!(
            f,
            "| {:<width_split$} | {:<width_metric$} | Min. | Epoch | Max. | Epoch |\n|{:->width_split$}--|{:->width_metric$}--|----------|----------|----------|----------|",
            "Split",
            "Metric",
            "",
            "",
            width_split = max_split_len,
            width_metric = max_metric_len,
        )?;

        // Total order on f64 for min_by/max_by: NaN compares greater than any
        // number (and equal to NaN), so partial_cmp below can never be None.
        fn cmp_f64(a: &f64, b: &f64) -> Ordering {
            match (a.is_nan(), b.is_nan()) {
                (true, true) => Ordering::Equal,
                (true, false) => Ordering::Greater,
                (false, true) => Ordering::Less,
                _ => a.partial_cmp(b).unwrap(),
            }
        }

        // Values below 1e-2 (including all negatives) are printed in
        // scientific notation; both variants are left-aligned to 9 chars so
        // the table columns line up.
        fn fmt_val(val: f64) -> String {
            if val < 1e-2 {
                format!("{val:<9.3e}")
            } else {
                format!("{val:<9.3}")
            }
        }

        // Closure (not fn) because it captures `f` and the column widths.
        let mut write_metrics_summary =
            |metrics: &[MetricSummary], split: &str| -> std::fmt::Result {
                for metric in metrics.iter() {
                    // MetricSummary::new never yields empty entries, but guard anyway.
                    if metric.entries.is_empty() {
                        continue; }

                    // unwrap is safe: entries is non-empty here.
                    let metric_min = metric
                        .entries
                        .iter()
                        .min_by(|a, b| cmp_f64(&a.value, &b.value))
                        .unwrap();
                    let metric_max = metric
                        .entries
                        .iter()
                        .max_by(|a, b| cmp_f64(&a.value, &b.value))
                        .unwrap();

                    // `{:<9?}` left-pads the epoch (Debug-formatted usize) to 9 chars.
                    writeln!(
                        f,
                        "| {:<width_split$} | {:<width_metric$} | {}| {:<9?}| {}| {:<9?}|",
                        split,
                        metric.name,
                        fmt_val(metric_min.value),
                        metric_min.step,
                        fmt_val(metric_max.value),
                        metric_max.step,
                        width_split = max_split_len,
                        width_metric = max_metric_len,
                    )?;
                }

                Ok(())
            };

        write_metrics_summary(&self.metrics.train, split_train)?;
        write_metrics_summary(&self.metrics.valid, split_valid)?;

        Ok(())
    }
}
238
/// Configuration needed to build a `LearnerSummary` after training.
pub(crate) struct LearnerSummaryConfig {
    /// The artifact directory containing the `train`/`valid` logs.
    pub(crate) directory: PathBuf,
    /// The names of the metrics to include in the summary.
    pub(crate) metrics: Vec<String>,
}
243
244impl LearnerSummaryConfig {
245 pub fn init(&self) -> Result<LearnerSummary, String> {
246 LearnerSummary::new(&self.directory, &self.metrics[..])
247 }
248}
249
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    #[should_panic = "Summary artifacts should exist"]
    fn test_artifact_dir_should_exist() {
        // The directory is never created, so construction must fail.
        let _summary = LearnerSummary::new("/tmp/learner-summary-not-found", &["Loss"])
            .expect("Summary artifacts should exist");
    }

    #[test]
    #[should_panic = "Summary artifacts should exist"]
    fn test_train_valid_artifacts_should_exist() {
        // An existing but empty directory (no train/valid subdirs) must fail.
        let dir = "/tmp/test-learner-summary-empty";
        std::fs::create_dir_all(dir).ok();
        let _summary = LearnerSummary::new(dir, &["Loss"]).expect("Summary artifacts should exist");
    }

    #[test]
    fn test_summary_should_be_empty() {
        // Epoch directories exist but contain no metric logs: the summary is
        // built successfully but holds no metrics.
        let dir = Path::new("/tmp/test-learner-summary-empty-metrics");
        std::fs::create_dir_all(dir).unwrap();
        std::fs::create_dir_all(dir.join("train/epoch-1")).unwrap();
        std::fs::create_dir_all(dir.join("valid/epoch-1")).unwrap();

        let summary = LearnerSummary::new(dir.to_str().unwrap(), &["Loss"])
            .expect("Summary artifacts should exist");

        assert_eq!(summary.epochs, 1);
        assert_eq!(summary.metrics.train.len(), 0);
        assert_eq!(summary.metrics.valid.len(), 0);

        std::fs::remove_dir_all(dir).unwrap();
    }

    #[test]
    fn test_summary_should_be_collected() {
        let dir = Path::new("/tmp/test-learner-summary");
        let train_dir = dir.join("train/epoch-1");
        let valid_dir = dir.join("valid/epoch-1");
        std::fs::create_dir_all(dir).unwrap();
        std::fs::create_dir_all(&train_dir).unwrap();
        std::fs::create_dir_all(&valid_dir).unwrap();

        // Two train values average to 1.5; the single valid value stays 1.0.
        std::fs::write(train_dir.join("Loss.log"), "1.0\n2.0").expect("Unable to write file");
        std::fs::write(valid_dir.join("Loss.log"), "1.0").expect("Unable to write file");

        let summary = LearnerSummary::new(dir.to_str().unwrap(), &["Loss"])
            .expect("Summary artifacts should exist");

        assert_eq!(summary.epochs, 1);
        assert_eq!(summary.metrics.train.len(), 1);
        assert_eq!(summary.metrics.valid.len(), 1);

        // Train split: one entry, mean of the logged values.
        let train_metric = &summary.metrics.train[0];
        assert_eq!(train_metric.name, "Loss");
        assert_eq!(train_metric.entries.len(), 1);
        let train_entry = &train_metric.entries[0];
        assert_eq!(train_entry.step, 1);
        assert_eq!(train_entry.value, 1.5);

        // Valid split: one entry with the single logged value.
        let valid_metric = &summary.metrics.valid[0];
        assert_eq!(valid_metric.name, "Loss");
        assert_eq!(valid_metric.entries.len(), 1);
        let valid_entry = &valid_metric.entries[0];
        assert_eq!(valid_entry.step, 1);
        assert_eq!(valid_entry.value, 1.0);

        std::fs::remove_dir_all(dir).unwrap();
    }
}