1use cubecl_common::config::RuntimeConfig;
2use cubecl_common::stub::Arc;
3
4use super::autodiff::AutodiffConfig;
5use super::fusion::FusionConfig;
6
7static BURN_GLOBAL_CONFIG: spin::Mutex<Option<Arc<BurnConfig>>> = spin::Mutex::new(None);
9
10#[derive(Default, Clone, Debug, serde::Serialize, serde::Deserialize)]
12pub struct BurnConfig {
13 #[serde(default)]
15 pub fusion: FusionConfig,
16
17 #[serde(default)]
19 pub autodiff: AutodiffConfig,
20}
21
22impl RuntimeConfig for BurnConfig {
23 fn storage() -> &'static spin::Mutex<Option<Arc<Self>>> {
24 &BURN_GLOBAL_CONFIG
25 }
26
27 fn file_names() -> &'static [&'static str] {
28 &["burn.toml", "Burn.toml"]
29 }
30
31 #[cfg(all(
34 feature = "std",
35 any(
36 target_os = "windows",
37 target_os = "linux",
38 target_os = "macos",
39 target_os = "android"
40 )
41 ))]
42 fn override_from_env(mut self) -> Self {
43 use super::fusion::FusionLogLevel;
44
45 if let Ok(val) = std::env::var("BURN_FUSION_LOG") {
46 let level = match val.to_ascii_lowercase().as_str() {
47 "disabled" | "off" | "0" => FusionLogLevel::Disabled,
48 "basic" => FusionLogLevel::Basic,
49 "medium" => FusionLogLevel::Medium,
50 "full" | "1" => FusionLogLevel::Full,
51 _ => self.fusion.logger.level,
52 };
53 self.fusion.logger.level = level;
54 if level != FusionLogLevel::Disabled {
56 self.fusion.logger.stderr = true;
57 }
58 }
59
60 self
61 }
62}