Skip to main content

burn_std/config/
base.rs

1use cubecl_common::config::RuntimeConfig;
2use cubecl_common::stub::Arc;
3
4use super::autodiff::AutodiffConfig;
5use super::fusion::FusionConfig;
6
7/// Static mutex holding the global Burn configuration, initialized as `None`.
8static BURN_GLOBAL_CONFIG: spin::Mutex<Option<Arc<BurnConfig>>> = spin::Mutex::new(None);
9
10/// Represents the global configuration for Burn.
11#[derive(Default, Clone, Debug, serde::Serialize, serde::Deserialize)]
12pub struct BurnConfig {
13    /// Configuration for operation fusion.
14    #[serde(default)]
15    pub fusion: FusionConfig,
16
17    /// Configuration for autodiff.
18    #[serde(default)]
19    pub autodiff: AutodiffConfig,
20}
21
22impl RuntimeConfig for BurnConfig {
23    fn storage() -> &'static spin::Mutex<Option<Arc<Self>>> {
24        &BURN_GLOBAL_CONFIG
25    }
26
27    fn file_names() -> &'static [&'static str] {
28        &["burn.toml", "Burn.toml"]
29    }
30
31    // Match cubecl-common's `std_io` cfg: only available on platforms where
32    // the trait method exists. See cubecl-common's build.rs.
33    #[cfg(all(
34        feature = "std",
35        any(
36            target_os = "windows",
37            target_os = "linux",
38            target_os = "macos",
39            target_os = "android"
40        )
41    ))]
42    fn override_from_env(mut self) -> Self {
43        use super::fusion::FusionLogLevel;
44
45        if let Ok(val) = std::env::var("BURN_FUSION_LOG") {
46            let level = match val.to_ascii_lowercase().as_str() {
47                "disabled" | "off" | "0" => FusionLogLevel::Disabled,
48                "basic" => FusionLogLevel::Basic,
49                "medium" => FusionLogLevel::Medium,
50                "full" | "1" => FusionLogLevel::Full,
51                _ => self.fusion.logger.level,
52            };
53            self.fusion.logger.level = level;
54            // Default to stderr so tests can see the output via `cargo test -- --nocapture`.
55            if level != FusionLogLevel::Disabled {
56                self.fusion.logger.stderr = true;
57            }
58        }
59
60        self
61    }
62}