1use std::fs::File;
32use std::io::{BufWriter, Write};
33use std::path::Path;
34
35use anyhow::{Context, Result};
36use serde::{Deserialize, Serialize};
37
38#[derive(Debug, Clone, Serialize, Deserialize)]
43pub struct TrainingMetrics {
44 pub iteration: u32,
46 pub loss_total: f32,
48 pub loss_l1: f32,
50 pub loss_ssim: f32,
52 pub loss_lpips: Option<f32>,
54 pub loss_reg: f32,
56 pub num_gaussians: u32,
58 pub lr_position: f32,
60 pub lr_scaling: f32,
62 pub lr_rotation: f32,
64 pub memory_mb: Option<u64>,
66 pub elapsed_seconds: f32,
68}
69
/// Output encoding used by [`MetricsWriter`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum MetricsFormat {
    /// Comma-separated values; a single header row is written when the file
    /// is created.
    Csv,
    /// One JSON object per line (JSON Lines).
    JsonLines,
}
78
79pub struct MetricsWriter {
84 format: MetricsFormat,
85 writer: BufWriter<File>,
86}
87
88impl MetricsWriter {
89 pub fn new(path: &Path, format: MetricsFormat) -> Result<Self> {
97 let file = File::create(path)
98 .with_context(|| format!("Failed to create metrics file: {}", path.display()))?;
99
100 let mut writer = BufWriter::new(file);
101
102 if matches!(format, MetricsFormat::Csv) {
104 writeln!(
105 writer,
106 "iteration,loss_total,loss_l1,loss_ssim,loss_lpips,loss_reg,num_gaussians,lr_position,lr_scaling,lr_rotation,memory_mb,elapsed_seconds"
107 )
108 .context("Failed to write CSV header")?;
109 }
110
111 Ok(Self { format, writer })
112 }
113
114 pub fn write_metrics(&mut self, metrics: &TrainingMetrics) -> Result<()> {
122 match self.format {
123 MetricsFormat::Csv => {
124 writeln!(
125 self.writer,
126 "{},{},{},{},{},{},{},{},{},{},{},{}",
127 metrics.iteration,
128 metrics.loss_total,
129 metrics.loss_l1,
130 metrics.loss_ssim,
131 metrics.loss_lpips.unwrap_or(0.0),
132 metrics.loss_reg,
133 metrics.num_gaussians,
134 metrics.lr_position,
135 metrics.lr_scaling,
136 metrics.lr_rotation,
137 metrics.memory_mb.unwrap_or(0),
138 metrics.elapsed_seconds,
139 )
140 .context("Failed to write CSV metrics line")?;
141 }
142 MetricsFormat::JsonLines => {
143 let json = serde_json::to_string(metrics)
144 .context("Failed to serialize metrics to JSON")?;
145 writeln!(self.writer, "{}", json).context("Failed to write JSON Lines metrics")?;
146 }
147 }
148
149 self.writer
151 .flush()
152 .context("Failed to flush metrics writer")?;
153 Ok(())
154 }
155
156 #[allow(dead_code)]
164 pub fn flush(&mut self) -> Result<()> {
165 self.writer
166 .flush()
167 .context("Failed to flush metrics writer")?;
168 Ok(())
169 }
170}
171
172impl Drop for MetricsWriter {
173 fn drop(&mut self) {
174 let _ = self.writer.flush();
176 }
177}
178
#[cfg(test)]
// `expect` is the intended way for these tests to fail with a readable message.
#[allow(clippy::expect_used)]
mod tests {
    use super::*;
    use std::env;
    use std::fs;
    use std::path::PathBuf;

    /// Returns a temp-file path for this test, removing any stale file left
    /// over from an earlier run.
    ///
    /// The process id is embedded in the name so that concurrently running
    /// test binaries do not overwrite each other's files.
    fn temp_path(name: &str) -> PathBuf {
        let path = env::temp_dir().join(format!(
            "oxigaf_metrics_test_{}_{}",
            std::process::id(),
            name
        ));
        let _ = fs::remove_file(&path);
        path
    }

    /// A fully populated metrics record shared by the single-record tests.
    fn sample_metrics() -> TrainingMetrics {
        TrainingMetrics {
            iteration: 1,
            loss_total: 1.234,
            loss_l1: 0.5,
            loss_ssim: 0.3,
            loss_lpips: Some(0.1),
            loss_reg: 0.334,
            num_gaussians: 50000,
            lr_position: 0.00016,
            lr_scaling: 0.005,
            lr_rotation: 0.001,
            memory_mb: Some(4096),
            elapsed_seconds: 120.5,
        }
    }

    #[test]
    fn test_csv_format_header() {
        let path = temp_path("csv_header.csv");

        let writer =
            MetricsWriter::new(&path, MetricsFormat::Csv).expect("Failed to create CSV writer");
        // The header is only buffered by `new`; dropping the writer flushes it.
        drop(writer);

        let content = fs::read_to_string(&path).expect("Failed to read metrics file");
        assert!(
            content.starts_with("iteration,loss_total"),
            "CSV header not found"
        );
        assert_eq!(content.lines().count(), 1, "Expected only the header line");

        let _ = fs::remove_file(&path);
    }

    #[test]
    fn test_csv_format_data() {
        let path = temp_path("csv_data.csv");

        let mut writer =
            MetricsWriter::new(&path, MetricsFormat::Csv).expect("Failed to create CSV writer");
        writer
            .write_metrics(&sample_metrics())
            .expect("Failed to write metrics");
        drop(writer);

        let content = fs::read_to_string(&path).expect("Failed to read metrics file");
        assert!(content.contains("1,1.234"), "CSV data not found");
        assert!(content.contains("50000"), "Gaussian count not found");

        let _ = fs::remove_file(&path);
    }

    #[test]
    fn test_json_lines_format() {
        let path = temp_path("jsonl.jsonl");

        let mut writer = MetricsWriter::new(&path, MetricsFormat::JsonLines)
            .expect("Failed to create JSON Lines writer");
        writer
            .write_metrics(&sample_metrics())
            .expect("Failed to write metrics");
        drop(writer);

        let content = fs::read_to_string(&path).expect("Failed to read metrics file");
        let first_line = content.lines().next().expect("No lines in file");
        let json: serde_json::Value =
            serde_json::from_str(first_line).expect("Failed to parse JSON");
        assert_eq!(json["iteration"], 1);
        assert_eq!(json["num_gaussians"], 50000);
        assert_eq!(json["memory_mb"], 4096);

        let _ = fs::remove_file(&path);
    }

    #[test]
    fn test_json_lines_missing_optionals_are_null() {
        let path = temp_path("jsonl_none.jsonl");

        let mut writer = MetricsWriter::new(&path, MetricsFormat::JsonLines)
            .expect("Failed to create JSON Lines writer");
        let metrics = TrainingMetrics {
            loss_lpips: None,
            memory_mb: None,
            ..sample_metrics()
        };
        writer
            .write_metrics(&metrics)
            .expect("Failed to write metrics");
        drop(writer);

        let content = fs::read_to_string(&path).expect("Failed to read metrics file");
        let first_line = content.lines().next().expect("No lines in file");
        let json: serde_json::Value =
            serde_json::from_str(first_line).expect("Failed to parse JSON");
        assert!(json["loss_lpips"].is_null(), "loss_lpips should be null");
        assert!(json["memory_mb"].is_null(), "memory_mb should be null");

        let _ = fs::remove_file(&path);
    }

    #[test]
    fn test_multiple_writes() {
        let path = temp_path("multiple.csv");

        let mut writer =
            MetricsWriter::new(&path, MetricsFormat::Csv).expect("Failed to create CSV writer");

        for i in 0..5 {
            let metrics = TrainingMetrics {
                iteration: i,
                loss_total: 1.0 - (i as f32 * 0.1),
                loss_l1: 0.5,
                loss_ssim: 0.3,
                loss_lpips: None,
                loss_reg: 0.2,
                num_gaussians: 50000 + i * 100,
                lr_position: 0.00016,
                lr_scaling: 0.005,
                lr_rotation: 0.001,
                memory_mb: None,
                elapsed_seconds: i as f32 * 10.0,
            };

            writer
                .write_metrics(&metrics)
                .expect("Failed to write metrics");
        }

        drop(writer);

        let content = fs::read_to_string(&path).expect("Failed to read metrics file");
        let lines: Vec<&str> = content.lines().collect();
        assert_eq!(lines.len(), 6, "Expected header + 5 data lines");
        assert!(
            lines[0].starts_with("iteration,"),
            "First line should be the CSV header"
        );

        let _ = fs::remove_file(&path);
    }

    #[test]
    fn test_manual_flush() {
        let path = temp_path("flush.csv");

        let mut writer =
            MetricsWriter::new(&path, MetricsFormat::Csv).expect("Failed to create CSV writer");
        writer
            .write_metrics(&sample_metrics())
            .expect("Failed to write metrics");
        writer.flush().expect("Manual flush failed");

        // Read while the writer is still alive: flushed data must already be on disk.
        let content = fs::read_to_string(&path).expect("Failed to read metrics file");
        assert!(content.contains("1,1.234"), "Flushed data not found on disk");

        drop(writer);

        let _ = fs::remove_file(&path);
    }
}