burn_train/learner/
application_logger.rs1use std::path::{Path, PathBuf};
2use tracing_core::{Level, LevelFilter};
3use tracing_subscriber::filter::filter_fn;
4use tracing_subscriber::prelude::*;
5use tracing_subscriber::{Layer, registry};
6
/// Something that can install an application-wide logger.
///
/// Implementors decide where log records are written.
pub trait ApplicationLoggerInstaller {
    /// Installs the logger.
    ///
    /// # Errors
    ///
    /// Returns a human-readable message describing why the logger could not be
    /// installed.
    fn install(&self) -> Result<(), String>;
}
12
/// [`ApplicationLoggerInstaller`] that writes application logs to a single,
/// non-rotating file.
#[derive(Debug, Clone)]
pub struct FileApplicationLoggerInstaller {
    /// Path of the log file the installed logger writes to.
    path: PathBuf,
}
17
18impl FileApplicationLoggerInstaller {
19 pub fn new(path: impl AsRef<Path>) -> Self {
21 Self {
22 path: path.as_ref().to_path_buf(),
23 }
24 }
25}
26
27impl ApplicationLoggerInstaller for FileApplicationLoggerInstaller {
28 fn install(&self) -> Result<(), String> {
29 let path = Path::new(&self.path);
30 let writer = tracing_appender::rolling::never(
31 path.parent().unwrap_or_else(|| Path::new(".")),
32 path.file_name().unwrap_or_else(|| {
33 panic!("The path '{}' to point to a file.", self.path.display())
34 }),
35 );
36 let layer = tracing_subscriber::fmt::layer()
37 .with_ansi(false)
38 .with_writer(writer)
39 .with_filter(LevelFilter::INFO)
40 .with_filter(filter_fn(|m| {
41 if let Some(path) = m.module_path() {
42 if path.starts_with("wgpu") && *m.level() >= Level::INFO {
44 return false;
45 }
46 }
47 true
48 }));
49
50 if registry().with(layer).try_init().is_err() {
51 return Err("Failed to install the file logger.".to_string());
52 }
53
54 let hook = std::panic::take_hook();
55 let file_path = self.path.to_owned();
56
57 std::panic::set_hook(Box::new(move |info| {
58 log::error!("PANIC => {info}");
59 eprintln!(
60 "=== PANIC ===\nA fatal error happened, you can check the experiment logs here => \
61 '{}'\n=============",
62 file_path.display()
63 );
64 hook(info);
65 }));
66
67 Ok(())
68 }
69}