burn_train/learner/
application_logger.rs

1use std::path::{Path, PathBuf};
2use tracing_core::{Level, LevelFilter};
3use tracing_subscriber::filter::filter_fn;
4use tracing_subscriber::prelude::*;
5use tracing_subscriber::{Layer, registry};
6
/// Installs an application-wide logger.
///
/// Implementors decide where log output is sent; see [`FileApplicationLoggerInstaller`] for an
/// implementation that writes to a local file.
pub trait ApplicationLoggerInstaller {
    /// Install the application logger.
    ///
    /// # Errors
    ///
    /// Returns a human-readable message describing why the logger could not be installed.
    fn install(&self) -> Result<(), String>;
}
12
/// Installs an application logger that writes its output to a local file.
///
/// The destination is fixed when the installer is built with
/// [`FileApplicationLoggerInstaller::new`].
#[derive(Debug, Clone)]
pub struct FileApplicationLoggerInstaller {
    /// Path of the file that receives the log output.
    path: PathBuf,
}
17
18impl FileApplicationLoggerInstaller {
19    /// Create a new file application logger.
20    pub fn new(path: impl AsRef<Path>) -> Self {
21        Self {
22            path: path.as_ref().to_path_buf(),
23        }
24    }
25}
26
27impl ApplicationLoggerInstaller for FileApplicationLoggerInstaller {
28    fn install(&self) -> Result<(), String> {
29        let path = Path::new(&self.path);
30        let writer = tracing_appender::rolling::never(
31            path.parent().unwrap_or_else(|| Path::new(".")),
32            path.file_name().unwrap_or_else(|| {
33                panic!("The path '{}' to point to a file.", self.path.display())
34            }),
35        );
36        let layer = tracing_subscriber::fmt::layer()
37            .with_ansi(false)
38            .with_writer(writer)
39            .with_filter(LevelFilter::INFO)
40            .with_filter(filter_fn(|m| {
41                if let Some(path) = m.module_path() {
42                    // The wgpu crate is logging too much, so we skip `info` level.
43                    if path.starts_with("wgpu") && *m.level() >= Level::INFO {
44                        return false;
45                    }
46                }
47                true
48            }));
49
50        if registry().with(layer).try_init().is_err() {
51            return Err("Failed to install the file logger.".to_string());
52        }
53
54        let hook = std::panic::take_hook();
55        let file_path = self.path.to_owned();
56
57        std::panic::set_hook(Box::new(move |info| {
58            log::error!("PANIC => {info}");
59            eprintln!(
60                "=== PANIC ===\nA fatal error happened, you can check the experiment logs here => \
61                    '{}'\n=============",
62                file_path.display()
63            );
64            hook(info);
65        }));
66
67        Ok(())
68    }
69}