cubecl_runtime/config/
base.rs

1use super::{autotune::AutotuneConfig, compilation::CompilationConfig, profiling::ProfilingConfig};
2use alloc::format;
3use alloc::string::{String, ToString};
4use alloc::sync::Arc;
5
6/// Static mutex holding the global configuration, initialized as `None`.
7static CUBE_GLOBAL_CONFIG: spin::Mutex<Option<Arc<GlobalConfig>>> = spin::Mutex::new(None);
8
9/// Represents the global configuration for CubeCL, combining profiling, autotuning, and compilation settings.
10#[derive(Default, Clone, Debug, serde::Serialize, serde::Deserialize)]
11pub struct GlobalConfig {
12    /// Configuration for profiling CubeCL operations.
13    #[serde(default)]
14    pub profiling: ProfilingConfig,
15
16    /// Configuration for autotuning performance parameters.
17    #[serde(default)]
18    pub autotune: AutotuneConfig,
19
20    /// Configuration for compilation settings.
21    #[serde(default)]
22    pub compilation: CompilationConfig,
23}
24
25impl GlobalConfig {
26    /// Retrieves the current global configuration, loading it from the current directory if not set.
27    ///
28    /// If no configuration is set, it attempts to load one from `cubecl.toml` or `CubeCL.toml` in the
29    /// current directory or its parents. If no file is found, a default configuration is used.
30    ///
31    /// # Notes
32    ///
33    /// Calling this function is somewhat expensive, because of a global static lock. The config format
34    /// is optimized for parsing, not for consumption. A good practice is to use a local static atomic
35    /// value that you can populate with the appropriate value from the global config during
36    /// initialization of the atomic value.
37    ///
38    /// For example, the autotune level uses a [core::sync::atomic::AtomicI32] with an initial
39    /// value of `-1` to indicate an uninitialized state. It is then set to the proper value based on
40    /// the [super::autotune::AutotuneLevel] config. All subsequent fetches of the value are
41    /// lock-free.
42    pub fn get() -> Arc<Self> {
43        let mut state = CUBE_GLOBAL_CONFIG.lock();
44        if state.as_ref().is_none() {
45            cfg_if::cfg_if! {
46                if #[cfg(std_io)]  {
47                    let config = Self::from_current_dir();
48                    let config = config.override_from_env();
49                } else {
50                    let config = Self::default();
51                }
52            }
53
54            *state = Some(Arc::new(config));
55        }
56
57        state.as_ref().cloned().unwrap()
58    }
59
60    #[cfg(std_io)]
61    /// Save the default configuration to the provided file path.
62    pub fn save_default<P: AsRef<std::path::Path>>(path: P) -> std::io::Result<()> {
63        use std::io::Write;
64
65        let config = Self::get();
66        let content =
67            toml::to_string_pretty(config.as_ref()).expect("Default config should be serializable");
68        let mut file = std::fs::File::create(path)?;
69        file.write_all(content.as_bytes())?;
70
71        Ok(())
72    }
73
74    /// Sets the global configuration to the provided value.
75    ///
76    /// # Panics
77    /// Panics if the configuration has already been set or read, as it cannot be overridden.
78    ///
79    /// # Warning
80    /// This method must be called at the start of the program, before any calls to `get`. Attempting
81    /// to set the configuration after it has been initialized will cause a panic.
82    pub fn set(config: Self) {
83        let mut state = CUBE_GLOBAL_CONFIG.lock();
84        if state.is_some() {
85            panic!("Cannot set the global configuration multiple times.");
86        }
87        *state = Some(Arc::new(config));
88    }
89
90    #[cfg(std_io)]
91    /// Overrides configuration fields based on environment variables.
92    pub fn override_from_env(mut self) -> Self {
93        use super::compilation::CompilationLogLevel;
94        use crate::config::{
95            autotune::{AutotuneLevel, AutotuneLogLevel},
96            profiling::ProfilingLogLevel,
97        };
98
99        if let Ok(val) = std::env::var("CUBECL_DEBUG_LOG") {
100            self.compilation.logger.level = CompilationLogLevel::Full;
101            self.profiling.logger.level = ProfilingLogLevel::Medium;
102            self.autotune.logger.level = AutotuneLogLevel::Full;
103
104            match val.as_str() {
105                "stdout" => {
106                    self.compilation.logger.stdout = true;
107                    self.profiling.logger.stdout = true;
108                    self.autotune.logger.stdout = true;
109                }
110                "stderr" => {
111                    self.compilation.logger.stderr = true;
112                    self.profiling.logger.stderr = true;
113                    self.autotune.logger.stderr = true;
114                }
115                "1" | "true" => {
116                    let file_path = "/tmp/cubecl.log";
117                    self.compilation.logger.file = Some(file_path.into());
118                    self.profiling.logger.file = Some(file_path.into());
119                    self.autotune.logger.file = Some(file_path.into());
120                }
121                "0" | "false" => {
122                    self.compilation.logger.level = CompilationLogLevel::Disabled;
123                    self.profiling.logger.level = ProfilingLogLevel::Disabled;
124                    self.autotune.logger.level = AutotuneLogLevel::Disabled;
125                }
126                file_path => {
127                    self.compilation.logger.file = Some(file_path.into());
128                    self.profiling.logger.file = Some(file_path.into());
129                    self.autotune.logger.file = Some(file_path.into());
130                }
131            }
132        };
133
134        if let Ok(val) = std::env::var("CUBECL_DEBUG_OPTION") {
135            match val.as_str() {
136                "debug" => {
137                    self.compilation.logger.level = CompilationLogLevel::Full;
138                    self.profiling.logger.level = ProfilingLogLevel::Medium;
139                    self.autotune.logger.level = AutotuneLogLevel::Full;
140                }
141                "debug-full" => {
142                    self.compilation.logger.level = CompilationLogLevel::Full;
143                    self.profiling.logger.level = ProfilingLogLevel::Full;
144                    self.autotune.logger.level = AutotuneLogLevel::Full;
145                }
146                "profile" => {
147                    self.profiling.logger.level = ProfilingLogLevel::Basic;
148                }
149                "profile-medium" => {
150                    self.profiling.logger.level = ProfilingLogLevel::Medium;
151                }
152                "profile-full" => {
153                    self.profiling.logger.level = ProfilingLogLevel::Full;
154                }
155                _ => {}
156            }
157        };
158
159        if let Ok(val) = std::env::var("CUBECL_AUTOTUNE_LEVEL") {
160            match val.as_str() {
161                "minimal" | "0" => {
162                    self.autotune.level = AutotuneLevel::Minimal;
163                }
164                "balanced" | "1" => {
165                    self.autotune.level = AutotuneLevel::Balanced;
166                }
167                "extensive" | "2" => {
168                    self.autotune.level = AutotuneLevel::Extensive;
169                }
170                "full" | "3" => {
171                    self.autotune.level = AutotuneLevel::Full;
172                }
173                _ => {}
174            }
175        }
176
177        self
178    }
179
180    // Loads configuration from `cubecl.toml` or `CubeCL.toml` in the current directory or its parents.
181    //
182    // Traverses up the directory tree until a valid configuration file is found or the root is reached.
183    // Returns a default configuration if no file is found.
184    #[cfg(std_io)]
185    fn from_current_dir() -> Self {
186        let mut dir = std::env::current_dir().unwrap();
187
188        loop {
189            if let Ok(content) = Self::from_file_path(dir.join("cubecl.toml")) {
190                return content;
191            }
192
193            if let Ok(content) = Self::from_file_path(dir.join("CubeCL.toml")) {
194                return content;
195            }
196
197            if !dir.pop() {
198                break;
199            }
200        }
201
202        Self::default()
203    }
204
205    // Loads configuration from a specified file path.
206    #[cfg(std_io)]
207    fn from_file_path<P: AsRef<std::path::Path>>(path: P) -> std::io::Result<Self> {
208        let content = std::fs::read_to_string(path)?;
209        let config: Self = match toml::from_str(&content) {
210            Ok(val) => val,
211            Err(err) => panic!("The file provided doesn't have the right format => {err:?}"),
212        };
213
214        Ok(config)
215    }
216}
217
/// How to format cubecl type names.
#[derive(Clone, Copy, Debug)]
pub enum TypeNameFormatLevel {
    /// No formatting is applied; the full type name is kept.
    Full,
    /// Only the name of the outer type is kept, without module path or generic arguments.
    Short,
    /// The outer type and every generic argument are kept, each without its module path.
    Balanced,
}

/// Formats a type name (for example one returned by [`core::any::type_name`]) according to
/// `level`.
///
/// - [`TypeNameFormatLevel::Full`] returns `name` unchanged.
/// - [`TypeNameFormatLevel::Short`] keeps the last `::` segment of the part before the first
///   `<`, e.g. `a::b::Foo<c::Bar>` becomes `Foo`.
/// - [`TypeNameFormatLevel::Balanced`] keeps the last `::` segment of the outer type and of every
///   generic argument (split on `", "`), flattened into one ` | `-separated list, e.g.
///   `a::Foo<b::Bar<c::Baz>, u32>` becomes `Foo | Bar | Baz | u32`.
pub fn type_name_format(name: &str, level: TypeNameFormatLevel) -> String {
    match level {
        TypeNameFormatLevel::Full => name.to_string(),
        TypeNameFormatLevel::Short => {
            // `split` always yields at least one item, so the fallback is never taken.
            let outer = name.split('<').next().unwrap_or(name);
            last_path_segment(outer).to_string()
        }
        TypeNameFormatLevel::Balanced => {
            // Everything after the first `<` (nested generics and closing `>`s included) is
            // treated as the generic argument list.
            let (outer, generics) = match name.split_once('<') {
                Some((outer, generics)) => (outer, Some(generics)),
                None => (name, None),
            };
            // Closing `>`s of an enclosing generic can remain on a trailing argument
            // (e.g. `c::Baz>`), so strip them from the segment.
            let outer = last_path_segment(outer).trim().replace('>', "");

            match generics {
                Some(generics) => {
                    let inside = type_name_list_format(generics, level);
                    format!("{outer}{inside}")
                }
                None => outer,
            }
        }
    }
}

/// Formats a `", "`-separated list of generic arguments, prefixing each formatted argument with
/// ` | `.
fn type_name_list_format(name: &str, level: TypeNameFormatLevel) -> String {
    name.split(", ").fold(String::new(), |mut acc, argument| {
        acc.push_str(" | ");
        acc.push_str(&type_name_format(argument, level));
        acc
    })
}

/// Returns the last `::`-separated segment of `path`, or `path` itself when it has no `::`.
fn last_path_segment(path: &str) -> &str {
    path.split("::").last().unwrap_or(path)
}
285
#[cfg(test)]
mod test {
    use super::*;

    const FULL_NAME: &str = "burn_cubecl::kernel::unary_numeric::unary_numeric::UnaryNumeric<f32, burn_cubecl::tensor::base::CubeTensor<_>::copy::Copy, cubecl_cuda::runtime::CudaRuntime>";

    #[test]
    fn test_format_name() {
        let name = type_name_format(FULL_NAME, TypeNameFormatLevel::Balanced);

        assert_eq!(name, "UnaryNumeric | f32 | CubeTensor | Copy | CudaRuntime");
    }

    #[test]
    fn test_format_name_full_is_identity() {
        assert_eq!(type_name_format(FULL_NAME, TypeNameFormatLevel::Full), FULL_NAME);
    }

    #[test]
    fn test_format_name_short() {
        assert_eq!(
            type_name_format(FULL_NAME, TypeNameFormatLevel::Short),
            "UnaryNumeric"
        );
    }

    #[test]
    fn test_format_name_without_generics() {
        for level in [TypeNameFormatLevel::Short, TypeNameFormatLevel::Balanced] {
            assert_eq!(type_name_format("u32", level), "u32");
            assert_eq!(type_name_format("core::option::Option", level), "Option");
        }
    }
}