Skip to main content

burn_std/config/
logger.rs

1use alloc::{string::String, vec::Vec};
2use core::fmt::Display;
3use cubecl_common::{
4    config::{
5        RuntimeConfig,
6        logger::{LogLevel, LoggerConfig, LoggerSinks},
7    },
8    stub::Arc,
9};
10
11use super::{autodiff::AutodiffLogLevel, base::BurnConfig, fusion::FusionLogLevel};
12
13static BURN_LOGGER: spin::Mutex<Option<Logger>> = spin::Mutex::new(None);
14
15#[cfg(feature = "std")]
16std::thread_local! {
17    static LOCAL_CONFIG: std::cell::OnceCell<Arc<BurnConfig>> =
18        const { std::cell::OnceCell::new() };
19}
20
21/// Returns the current [`BurnConfig`], cached in thread-local storage on native targets.
22///
23/// On the first call from a given thread this fetches the global config via
24/// [`BurnConfig::get`] (which locks a spin mutex) and caches the `Arc` thread-locally.
25/// Subsequent calls on the same thread only pay an `Arc` clone. On `no_std` builds this
26/// is equivalent to [`BurnConfig::get`].
27///
28/// Safe because [`BurnConfig::set`] panics after the first read, so the cached snapshot
29/// matches the global singleton for the whole program lifetime.
30pub fn config() -> Arc<BurnConfig> {
31    #[cfg(feature = "std")]
32    {
33        LOCAL_CONFIG.with(|cell| cell.get_or_init(BurnConfig::get).clone())
34    }
35    #[cfg(not(feature = "std"))]
36    {
37        BurnConfig::get()
38    }
39}
40
41/// Central logging utility for Burn, managing one sink registry shared across subsystems.
42#[derive(Debug)]
43pub struct Logger {
44    sinks: LoggerSinks,
45    fusion_index: Vec<usize>,
46    autodiff_index: Vec<usize>,
47    /// The configuration snapshot the logger was initialized with.
48    pub config: Arc<BurnConfig>,
49}
50
51impl Default for Logger {
52    fn default() -> Self {
53        Self::new()
54    }
55}
56
57impl Logger {
58    /// Creates a new `Logger` from the current global `BurnConfig`.
59    ///
60    /// Note that creating a logger is somewhat expensive because it opens file handles for any
61    /// sink configured with a file path.
62    pub fn new() -> Self {
63        let config = BurnConfig::get();
64        let mut sinks = LoggerSinks::new();
65
66        let fusion_index = register_enabled(
67            &mut sinks,
68            &config.fusion.logger,
69            config.fusion.logger.level != FusionLogLevel::Disabled,
70        );
71        let autodiff_index = register_enabled(
72            &mut sinks,
73            &config.autodiff.logger,
74            config.autodiff.logger.level != AutodiffLogLevel::Disabled,
75        );
76
77        Self {
78            sinks,
79            fusion_index,
80            autodiff_index,
81            config,
82        }
83    }
84
85    /// Writes `msg` to all configured fusion sinks.
86    pub fn log_fusion<S: Display>(&mut self, msg: &S) {
87        self.sinks.log(&self.fusion_index, msg);
88    }
89
90    /// Writes `msg` to all configured autodiff sinks.
91    pub fn log_autodiff<S: Display>(&mut self, msg: &S) {
92        self.sinks.log(&self.autodiff_index, msg);
93    }
94
95    /// Returns the current fusion log level.
96    pub fn log_level_fusion(&self) -> FusionLogLevel {
97        self.config.fusion.logger.level
98    }
99
100    /// Returns the current autodiff log level.
101    pub fn log_level_autodiff(&self) -> AutodiffLogLevel {
102        self.config.autodiff.logger.level
103    }
104}
105
106fn register_enabled<L: LogLevel>(
107    sinks: &mut LoggerSinks,
108    config: &LoggerConfig<L>,
109    enabled: bool,
110) -> Vec<usize> {
111    if enabled {
112        sinks.register(config)
113    } else {
114        Vec::new()
115    }
116}
117
118/// Emit a fusion log message when the configured level is at least `level`.
119///
120/// The message is only constructed when logging is enabled.
121pub fn log_fusion<F>(level: FusionLogLevel, f: F)
122where
123    F: FnOnce() -> String,
124{
125    let current = config().fusion.logger.level;
126    if current < level {
127        return;
128    }
129    let msg = f();
130    let mut guard = BURN_LOGGER.lock();
131    let logger = guard.get_or_insert_with(Logger::new);
132    logger.log_fusion(&msg);
133}
134
135/// Emit an autodiff log message when the configured level is at least `level`.
136///
137/// The message is only constructed when logging is enabled.
138pub fn log_autodiff<F>(level: AutodiffLogLevel, f: F)
139where
140    F: FnOnce() -> String,
141{
142    let current = config().autodiff.logger.level;
143    if current < level {
144        return;
145    }
146    let msg = f();
147    let mut guard = BURN_LOGGER.lock();
148    let logger = guard.get_or_insert_with(Logger::new);
149    logger.log_autodiff(&msg);
150}