burn_std/config/
logger.rs1use alloc::{string::String, vec::Vec};
2use core::fmt::Display;
3use cubecl_common::{
4 config::{
5 RuntimeConfig,
6 logger::{LogLevel, LoggerConfig, LoggerSinks},
7 },
8 stub::Arc,
9};
10
11use super::{autodiff::AutodiffLogLevel, base::BurnConfig, fusion::FusionLogLevel};
12
13static BURN_LOGGER: spin::Mutex<Option<Logger>> = spin::Mutex::new(None);
14
15#[cfg(feature = "std")]
16std::thread_local! {
17 static LOCAL_CONFIG: std::cell::OnceCell<Arc<BurnConfig>> =
18 const { std::cell::OnceCell::new() };
19}
20
21pub fn config() -> Arc<BurnConfig> {
31 #[cfg(feature = "std")]
32 {
33 LOCAL_CONFIG.with(|cell| cell.get_or_init(BurnConfig::get).clone())
34 }
35 #[cfg(not(feature = "std"))]
36 {
37 BurnConfig::get()
38 }
39}
40
41#[derive(Debug)]
43pub struct Logger {
44 sinks: LoggerSinks,
45 fusion_index: Vec<usize>,
46 autodiff_index: Vec<usize>,
47 pub config: Arc<BurnConfig>,
49}
50
51impl Default for Logger {
52 fn default() -> Self {
53 Self::new()
54 }
55}
56
57impl Logger {
58 pub fn new() -> Self {
63 let config = BurnConfig::get();
64 let mut sinks = LoggerSinks::new();
65
66 let fusion_index = register_enabled(
67 &mut sinks,
68 &config.fusion.logger,
69 config.fusion.logger.level != FusionLogLevel::Disabled,
70 );
71 let autodiff_index = register_enabled(
72 &mut sinks,
73 &config.autodiff.logger,
74 config.autodiff.logger.level != AutodiffLogLevel::Disabled,
75 );
76
77 Self {
78 sinks,
79 fusion_index,
80 autodiff_index,
81 config,
82 }
83 }
84
85 pub fn log_fusion<S: Display>(&mut self, msg: &S) {
87 self.sinks.log(&self.fusion_index, msg);
88 }
89
90 pub fn log_autodiff<S: Display>(&mut self, msg: &S) {
92 self.sinks.log(&self.autodiff_index, msg);
93 }
94
95 pub fn log_level_fusion(&self) -> FusionLogLevel {
97 self.config.fusion.logger.level
98 }
99
100 pub fn log_level_autodiff(&self) -> AutodiffLogLevel {
102 self.config.autodiff.logger.level
103 }
104}
105
106fn register_enabled<L: LogLevel>(
107 sinks: &mut LoggerSinks,
108 config: &LoggerConfig<L>,
109 enabled: bool,
110) -> Vec<usize> {
111 if enabled {
112 sinks.register(config)
113 } else {
114 Vec::new()
115 }
116}
117
118pub fn log_fusion<F>(level: FusionLogLevel, f: F)
122where
123 F: FnOnce() -> String,
124{
125 let current = config().fusion.logger.level;
126 if current < level {
127 return;
128 }
129 let msg = f();
130 let mut guard = BURN_LOGGER.lock();
131 let logger = guard.get_or_insert_with(Logger::new);
132 logger.log_fusion(&msg);
133}
134
135pub fn log_autodiff<F>(level: AutodiffLogLevel, f: F)
139where
140 F: FnOnce() -> String,
141{
142 let current = config().autodiff.logger.level;
143 if current < level {
144 return;
145 }
146 let msg = f();
147 let mut guard = BURN_LOGGER.lock();
148 let logger = guard.get_or_insert_with(Logger::new);
149 logger.log_autodiff(&msg);
150}