use alloc::{string::String, vec::Vec};
use core::fmt::Display;
use cubecl_common::{
config::{
RuntimeConfig,
logger::{LogLevel, LoggerConfig, LoggerSinks},
},
stub::Arc,
};
use super::{autodiff::AutodiffLogLevel, base::BurnConfig, fusion::FusionLogLevel};
/// Process-wide logger shared by [`log_fusion`] and [`log_autodiff`].
///
/// Starts as `None`; the first message that passes the level filter creates
/// the [`Logger`] via [`Logger::new`] while the lock is held.
static BURN_LOGGER: spin::Mutex<Option<Logger>> = spin::Mutex::new(None);
// Per-thread cache of the Burn configuration (std builds only).
// `config()` fills it from `BurnConfig::get` on the first call made on a given
// thread; `OnceCell` means the cached value is not replaced afterwards.
#[cfg(feature = "std")]
std::thread_local! {
static LOCAL_CONFIG: std::cell::OnceCell<Arc<BurnConfig>> =
const { std::cell::OnceCell::new() };
}
/// Returns the active Burn configuration.
///
/// With the `std` feature, the value returned by [`BurnConfig::get`] is cached
/// per thread on the first call, and later calls on that thread hand out clones
/// of the cached `Arc`. Without `std`, [`BurnConfig::get`] is called every time.
pub fn config() -> Arc<BurnConfig> {
    #[cfg(feature = "std")]
    {
        LOCAL_CONFIG.with(|slot| {
            let cached = slot.get_or_init(BurnConfig::get);
            Arc::clone(cached)
        })
    }
    #[cfg(not(feature = "std"))]
    {
        BurnConfig::get()
    }
}
/// Logger that forwards fusion and autodiff messages to their configured sinks.
#[derive(Debug)]
pub struct Logger {
    // Every sink registered by this logger (fusion and autodiff alike).
    sinks: LoggerSinks,
    // Handles returned by `LoggerSinks::register` for the fusion logger config;
    // empty when fusion logging was disabled at construction time.
    fusion_index: Vec<usize>,
    // Handles returned by `LoggerSinks::register` for the autodiff logger config;
    // empty when autodiff logging was disabled at construction time.
    autodiff_index: Vec<usize>,
    // Configuration the logger was built from (taken in `Logger::new`).
    pub config: Arc<BurnConfig>,
}
impl Default for Logger {
fn default() -> Self {
Self::new()
}
}
impl Logger {
    /// Builds a logger from the configuration returned by [`BurnConfig::get`].
    ///
    /// Fusion sinks are registered first, then autodiff sinks. A subsystem
    /// whose configured level is `Disabled` registers nothing and ends up with
    /// an empty handle list.
    pub fn new() -> Self {
        let config = BurnConfig::get();
        let mut sinks = LoggerSinks::new();

        let fusion_logger = &config.fusion.logger;
        let fusion_on = fusion_logger.level != FusionLogLevel::Disabled;
        let fusion_index = register_enabled(&mut sinks, fusion_logger, fusion_on);

        let autodiff_logger = &config.autodiff.logger;
        let autodiff_on = autodiff_logger.level != AutodiffLogLevel::Disabled;
        let autodiff_index = register_enabled(&mut sinks, autodiff_logger, autodiff_on);

        Self { sinks, fusion_index, autodiff_index, config }
    }

    /// Forwards `msg` to the sinks registered for fusion logging.
    pub fn log_fusion<S: Display>(&mut self, msg: &S) {
        let targets = &self.fusion_index;
        self.sinks.log(targets, msg);
    }

    /// Forwards `msg` to the sinks registered for autodiff logging.
    pub fn log_autodiff<S: Display>(&mut self, msg: &S) {
        let targets = &self.autodiff_index;
        self.sinks.log(targets, msg);
    }

    /// Fusion log level recorded in this logger's configuration.
    pub fn log_level_fusion(&self) -> FusionLogLevel {
        let fusion_logger = &self.config.fusion.logger;
        fusion_logger.level
    }

    /// Autodiff log level recorded in this logger's configuration.
    pub fn log_level_autodiff(&self) -> AutodiffLogLevel {
        let autodiff_logger = &self.config.autodiff.logger;
        autodiff_logger.level
    }
}
/// Registers `config` with `sinks` only when `enabled` is true.
///
/// Returns the handles produced by `LoggerSinks::register`. When disabled,
/// nothing is registered and an empty list is returned.
fn register_enabled<L: LogLevel>(
    sinks: &mut LoggerSinks,
    config: &LoggerConfig<L>,
    enabled: bool,
) -> Vec<usize> {
    if !enabled {
        return Vec::new();
    }
    sinks.register(config)
}
/// Logs a fusion message whose text is produced lazily by `f`.
///
/// Returns without calling `f` when the configured fusion level (from
/// [`config`]) compares below `level`. This assumes the level enum orders from
/// least to most verbose; confirm against `FusionLogLevel`. Otherwise the
/// message is built before the global logger lock is taken. The shared
/// [`Logger`] is created on first use.
pub fn log_fusion<F>(level: FusionLogLevel, f: F)
where
    F: FnOnce() -> String,
{
    let configured = config().fusion.logger.level;
    if configured < level {
        return;
    }

    // Format outside the lock so `f` never runs while other threads spin on it.
    let message = f();
    BURN_LOGGER
        .lock()
        .get_or_insert_with(Logger::new)
        .log_fusion(&message);
}
/// Logs an autodiff message whose text is produced lazily by `f`.
///
/// Returns without calling `f` when the configured autodiff level (from
/// [`config`]) compares below `level`. This assumes the level enum orders from
/// least to most verbose; confirm against `AutodiffLogLevel`. Otherwise the
/// message is built before the global logger lock is taken. The shared
/// [`Logger`] is created on first use.
pub fn log_autodiff<F>(level: AutodiffLogLevel, f: F)
where
    F: FnOnce() -> String,
{
    let configured = config().autodiff.logger.level;
    if configured < level {
        return;
    }

    // Format outside the lock so `f` never runs while other threads spin on it.
    let message = f();
    BURN_LOGGER
        .lock()
        .get_or_insert_with(Logger::new)
        .log_autodiff(&message);
}