burn_train/logger/
metric.rs1use super::{AsyncLogger, FileLogger, InMemoryLogger, Logger};
2use crate::metric::{MetricEntry, NumericEntry};
3use std::{
4 collections::HashMap,
5 fs,
6 path::{Path, PathBuf},
7};
8
/// Prefix of the per-epoch directory names (`epoch-<n>`) used by the file metric logger.
const EPOCH_PREFIX: &str = "epoch-";
10
/// Sink for metric entries that are grouped per epoch and can be read back later.
///
/// Both implementations in this file number epochs starting from 1.
pub trait MetricLogger: Send {
    /// Records a single metric entry for the epoch currently in progress.
    fn log(&mut self, item: &MetricEntry);

    /// Signals that `epoch` is finished; entries logged afterwards belong to the next one.
    fn end_epoch(&mut self, epoch: usize);

    /// Reads back the numeric entries recorded under the metric `name` for `epoch`.
    ///
    /// # Errors
    ///
    /// Returns a message describing the failure when entries cannot be read.
    fn read_numeric(&mut self, name: &str, epoch: usize) -> Result<Vec<NumericEntry>, String>;
}
30
/// Metric logger that writes every metric to its own `<name>.log` file under a root directory.
pub struct FileMetricLogger {
    // One asynchronous file logger per metric name; emptied at the end of every epoch.
    loggers: HashMap<String, AsyncLogger<String>>,
    // Root directory under which tag/epoch subdirectories and log files are created.
    directory: PathBuf,
    // `Some(n)` while logging training epoch `n`; `None` for an evaluation logger.
    epoch: Option<usize>,
}
37
38impl FileMetricLogger {
39 pub fn new_train(directory: impl AsRef<Path>) -> Self {
49 Self {
50 loggers: HashMap::new(),
51 directory: directory.as_ref().to_path_buf(),
52 epoch: Some(1),
53 }
54 }
55
56 pub fn new_eval(directory: impl AsRef<Path>) -> Self {
66 Self {
67 loggers: HashMap::new(),
68 directory: directory.as_ref().to_path_buf(),
69 epoch: None,
70 }
71 }
72
73 pub(crate) fn epochs(&self) -> usize {
75 if self.epoch.is_none() {
76 log::warn!("Number of epochs not available when testing.");
77 return 0;
78 }
79
80 let mut max_epoch = 0;
81
82 for path in fs::read_dir(&self.directory).unwrap() {
83 let path = path.unwrap();
84
85 if fs::metadata(path.path()).unwrap().is_dir() {
86 let dir_name = path.file_name().into_string().unwrap();
87
88 if !dir_name.starts_with(EPOCH_PREFIX) {
89 continue;
90 }
91
92 let epoch = dir_name.replace(EPOCH_PREFIX, "").parse::<usize>().ok();
93
94 if let Some(epoch) = epoch
95 && epoch > max_epoch
96 {
97 max_epoch = epoch;
98 }
99 }
100 }
101
102 max_epoch
103 }
104
105 fn train_directory(&self, tags: Option<&String>, epoch: usize) -> PathBuf {
106 let name = format!("{EPOCH_PREFIX}{epoch}");
107
108 match tags {
109 Some(tags) => self.directory.join(tags).join(name),
110 None => self.directory.join(name),
111 }
112 }
113
114 fn eval_directory(&self, tags: Option<&String>) -> PathBuf {
115 match tags {
116 Some(tags) => self.directory.join(tags),
117 None => self.directory.clone(),
118 }
119 }
120
121 fn file_path(&self, tags: Option<&String>, name: &str, epoch: Option<usize>) -> PathBuf {
122 let directory = match epoch {
123 Some(epoch) => self.train_directory(tags, epoch),
124 None => self.eval_directory(tags),
125 };
126 let name = name.replace(' ', "_");
127 let name = format!("{name}.log");
128 directory.join(name)
129 }
130
131 fn create_directory(&self, tags: Option<&String>, epoch: Option<usize>) {
132 let directory = match epoch {
133 Some(epoch) => self.train_directory(tags, epoch),
134 None => self.eval_directory(tags),
135 };
136 std::fs::create_dir_all(directory).ok();
137 }
138}
139
140impl FileMetricLogger {
141 fn log_item(&mut self, tags: Option<&String>, item: &MetricEntry) {
142 let key = &item.name;
143 let value = &item.serialize;
144
145 let logger = match self.loggers.get_mut(key.as_ref()) {
146 Some(val) => val,
147 None => {
148 self.create_directory(tags, self.epoch);
149
150 let file_path = self.file_path(tags, key, self.epoch);
151 let logger = FileLogger::new(file_path);
152 let logger = AsyncLogger::new(logger);
153
154 self.loggers.insert(key.to_string(), logger);
155 self.loggers
156 .get_mut(key.as_ref())
157 .expect("Can get the previously saved logger.")
158 }
159 };
160
161 logger.log(value.clone());
162 }
163
164 fn log_tags(&mut self, item: &MetricEntry) {
165 let mut tags = String::new();
166 item.tags.iter().for_each(|tag| tags += tag.as_str());
167 let tags = tags.replace(" ", "-").trim().to_lowercase();
168 self.log_item(Some(&tags), item);
169 }
170}
171
172impl MetricLogger for FileMetricLogger {
173 fn log(&mut self, item: &MetricEntry) {
174 match item.tags.is_empty() {
175 true => self.log_item(None, item),
176 false => self.log_tags(item),
177 }
178 }
179
180 fn end_epoch(&mut self, epoch: usize) {
181 self.loggers.clear();
182 if self.epoch.is_none() {
183 panic!("Only evaluation logger supported.");
184 }
185 self.epoch = Some(epoch + 1);
186 }
187
188 fn read_numeric(&mut self, name: &str, epoch: usize) -> Result<Vec<NumericEntry>, String> {
189 if let Some(value) = self.loggers.get(name) {
190 value.sync()
191 }
192
193 let file_path = self.file_path(None, name, Some(epoch));
194
195 let mut errors = false;
196
197 let data = std::fs::read_to_string(file_path)
198 .unwrap_or_default()
199 .split('\n')
200 .filter_map(|value| {
201 if value.is_empty() {
202 None
203 } else {
204 match NumericEntry::deserialize(value) {
205 Ok(value) => Some(value),
206 Err(err) => {
207 log::error!("{err}");
208 errors = true;
209 None
210 }
211 }
212 }
213 })
214 .collect();
215
216 if errors {
217 Err("Parsing numeric entry errors".to_string())
218 } else {
219 Ok(data)
220 }
221 }
222}
223
/// Metric logger that keeps every entry in memory instead of writing it to disk.
#[derive(Default)]
pub struct InMemoryMetricLogger {
    // Per metric name: one in-memory logger per epoch, the last one being the current epoch.
    values: HashMap<String, Vec<InMemoryLogger>>,
}
229
230impl InMemoryMetricLogger {
231 pub fn new() -> Self {
233 Self::default()
234 }
235}
236impl MetricLogger for InMemoryMetricLogger {
237 fn log(&mut self, item: &MetricEntry) {
238 if !self.values.contains_key(item.name.as_ref()) {
239 self.values
240 .insert(item.name.to_string(), vec![InMemoryLogger::default()]);
241 }
242
243 let values = self.values.get_mut(item.name.as_ref()).unwrap();
244
245 values.last_mut().unwrap().log(item.serialize.clone());
246 }
247
248 fn end_epoch(&mut self, _epoch: usize) {
249 for (_, values) in self.values.iter_mut() {
250 values.push(InMemoryLogger::default());
251 }
252 }
253
254 fn read_numeric(&mut self, name: &str, epoch: usize) -> Result<Vec<NumericEntry>, String> {
255 let values = match self.values.get(name) {
256 Some(values) => values,
257 None => return Ok(Vec::new()),
258 };
259
260 match values.get(epoch - 1) {
261 Some(logger) => Ok(logger
262 .values
263 .iter()
264 .filter_map(|value| NumericEntry::deserialize(value).ok())
265 .collect()),
266 None => Ok(Vec::new()),
267 }
268 }
269}