cubecl_runtime/config/
base.rs

1use crate::config::memory::MemoryConfig;
2use crate::config::streaming::StreamingConfig;
3
4use super::{autotune::AutotuneConfig, compilation::CompilationConfig, profiling::ProfilingConfig};
5use alloc::format;
6use alloc::string::{String, ToString};
7use alloc::sync::Arc;
8
/// Static mutex holding the global configuration, initialized as `None`.
///
/// The slot is filled at most once: either explicitly through [`GlobalConfig::set`] or lazily by
/// the first call to [`GlobalConfig::get`]. Once filled, the stored `Arc` is only ever cloned.
static CUBE_GLOBAL_CONFIG: spin::Mutex<Option<Arc<GlobalConfig>>> = spin::Mutex::new(None);
11
/// Represents the global configuration for CubeCL, combining profiling, autotuning, compilation,
/// streaming, and memory settings.
///
/// Every section is marked `#[serde(default)]`, so a serialized configuration may omit any of
/// them; omitted sections take their `Default` value.
#[derive(Default, Clone, Debug, serde::Serialize, serde::Deserialize)]
pub struct GlobalConfig {
    /// Configuration for profiling CubeCL operations.
    #[serde(default)]
    pub profiling: ProfilingConfig,

    /// Configuration for autotuning performance parameters.
    #[serde(default)]
    pub autotune: AutotuneConfig,

    /// Configuration for compilation settings.
    #[serde(default)]
    pub compilation: CompilationConfig,

    /// Configuration for streaming settings.
    #[serde(default)]
    pub streaming: StreamingConfig,

    /// Configuration for memory settings.
    #[serde(default)]
    pub memory: MemoryConfig,
}
35
36impl GlobalConfig {
37    /// Retrieves the current global configuration, loading it from the current directory if not set.
38    ///
39    /// If no configuration is set, it attempts to load one from `cubecl.toml` or `CubeCL.toml` in the
40    /// current directory or its parents. If no file is found, a default configuration is used.
41    ///
42    /// # Notes
43    ///
44    /// Calling this function is somewhat expensive, because of a global static lock. The config format
45    /// is optimized for parsing, not for consumption. A good practice is to use a local static atomic
46    /// value that you can populate with the appropriate value from the global config during
47    /// initialization of the atomic value.
48    ///
49    /// For example, the autotune level uses a [core::sync::atomic::AtomicI32] with an initial
50    /// value of `-1` to indicate an uninitialized state. It is then set to the proper value based on
51    /// the [super::autotune::AutotuneLevel] config. All subsequent fetches of the value are
52    /// lock-free.
53    pub fn get() -> Arc<Self> {
54        let mut state = CUBE_GLOBAL_CONFIG.lock();
55        if state.as_ref().is_none() {
56            cfg_if::cfg_if! {
57                if #[cfg(std_io)]  {
58                    let config = Self::from_current_dir();
59                    let config = config.override_from_env();
60                } else {
61                    let config = Self::default();
62                }
63            }
64
65            *state = Some(Arc::new(config));
66        }
67
68        state.as_ref().cloned().unwrap()
69    }
70
71    #[cfg(std_io)]
72    /// Save the default configuration to the provided file path.
73    pub fn save_default<P: AsRef<std::path::Path>>(path: P) -> std::io::Result<()> {
74        use std::io::Write;
75
76        let config = Self::get();
77        let content =
78            toml::to_string_pretty(config.as_ref()).expect("Default config should be serializable");
79        let mut file = std::fs::File::create(path)?;
80        file.write_all(content.as_bytes())?;
81
82        Ok(())
83    }
84
85    /// Sets the global configuration to the provided value.
86    ///
87    /// # Panics
88    /// Panics if the configuration has already been set or read, as it cannot be overridden.
89    ///
90    /// # Warning
91    /// This method must be called at the start of the program, before any calls to `get`. Attempting
92    /// to set the configuration after it has been initialized will cause a panic.
93    pub fn set(config: Self) {
94        let mut state = CUBE_GLOBAL_CONFIG.lock();
95        if state.is_some() {
96            panic!("Cannot set the global configuration multiple times.");
97        }
98        *state = Some(Arc::new(config));
99    }
100
101    #[cfg(std_io)]
102    /// Overrides configuration fields based on environment variables.
103    pub fn override_from_env(mut self) -> Self {
104        use super::compilation::CompilationLogLevel;
105        use crate::config::{
106            autotune::{AutotuneLevel, AutotuneLogLevel},
107            profiling::ProfilingLogLevel,
108        };
109
110        if let Ok(val) = std::env::var("CUBECL_DEBUG_LOG") {
111            self.compilation.logger.level = CompilationLogLevel::Full;
112            self.profiling.logger.level = ProfilingLogLevel::Medium;
113            self.autotune.logger.level = AutotuneLogLevel::Full;
114
115            match val.as_str() {
116                "stdout" => {
117                    self.compilation.logger.stdout = true;
118                    self.profiling.logger.stdout = true;
119                    self.autotune.logger.stdout = true;
120                }
121                "stderr" => {
122                    self.compilation.logger.stderr = true;
123                    self.profiling.logger.stderr = true;
124                    self.autotune.logger.stderr = true;
125                }
126                "1" | "true" => {
127                    let file_path = "/tmp/cubecl.log";
128                    self.compilation.logger.file = Some(file_path.into());
129                    self.profiling.logger.file = Some(file_path.into());
130                    self.autotune.logger.file = Some(file_path.into());
131                }
132                "0" | "false" => {
133                    self.compilation.logger.level = CompilationLogLevel::Disabled;
134                    self.profiling.logger.level = ProfilingLogLevel::Disabled;
135                    self.autotune.logger.level = AutotuneLogLevel::Disabled;
136                }
137                file_path => {
138                    self.compilation.logger.file = Some(file_path.into());
139                    self.profiling.logger.file = Some(file_path.into());
140                    self.autotune.logger.file = Some(file_path.into());
141                }
142            }
143        };
144
145        if let Ok(val) = std::env::var("CUBECL_DEBUG_OPTION") {
146            match val.as_str() {
147                "debug" => {
148                    self.compilation.logger.level = CompilationLogLevel::Full;
149                    self.profiling.logger.level = ProfilingLogLevel::Medium;
150                    self.autotune.logger.level = AutotuneLogLevel::Full;
151                }
152                "debug-full" => {
153                    self.compilation.logger.level = CompilationLogLevel::Full;
154                    self.profiling.logger.level = ProfilingLogLevel::Full;
155                    self.autotune.logger.level = AutotuneLogLevel::Full;
156                }
157                "profile" => {
158                    self.profiling.logger.level = ProfilingLogLevel::Basic;
159                }
160                "profile-medium" => {
161                    self.profiling.logger.level = ProfilingLogLevel::Medium;
162                }
163                "profile-full" => {
164                    self.profiling.logger.level = ProfilingLogLevel::Full;
165                }
166                _ => {}
167            }
168        };
169
170        if let Ok(val) = std::env::var("CUBECL_AUTOTUNE_LEVEL") {
171            match val.as_str() {
172                "minimal" | "0" => {
173                    self.autotune.level = AutotuneLevel::Minimal;
174                }
175                "balanced" | "1" => {
176                    self.autotune.level = AutotuneLevel::Balanced;
177                }
178                "extensive" | "2" => {
179                    self.autotune.level = AutotuneLevel::Extensive;
180                }
181                "full" | "3" => {
182                    self.autotune.level = AutotuneLevel::Full;
183                }
184                _ => {}
185            }
186        }
187
188        self
189    }
190
191    // Loads configuration from `cubecl.toml` or `CubeCL.toml` in the current directory or its parents.
192    //
193    // Traverses up the directory tree until a valid configuration file is found or the root is reached.
194    // Returns a default configuration if no file is found.
195    #[cfg(std_io)]
196    fn from_current_dir() -> Self {
197        let mut dir = std::env::current_dir().unwrap();
198
199        loop {
200            if let Ok(content) = Self::from_file_path(dir.join("cubecl.toml")) {
201                return content;
202            }
203
204            if let Ok(content) = Self::from_file_path(dir.join("CubeCL.toml")) {
205                return content;
206            }
207
208            if !dir.pop() {
209                break;
210            }
211        }
212
213        Self::default()
214    }
215
216    // Loads configuration from a specified file path.
217    #[cfg(std_io)]
218    fn from_file_path<P: AsRef<std::path::Path>>(path: P) -> std::io::Result<Self> {
219        let content = std::fs::read_to_string(path)?;
220        let config: Self = match toml::from_str(&content) {
221            Ok(val) => val,
222            Err(err) => panic!("The file provided doesn't have the right format => {err:?}"),
223        };
224
225        Ok(config)
226    }
227}
228
/// How to format cubecl type names.
#[derive(Clone, Copy, Debug)]
pub enum TypeNameFormatLevel {
    /// No formatting is applied, full information is included.
    Full,
    /// Most information is removed for a small formatted name.
    Short,
    /// Balanced info is kept.
    Balanced,
}

/// Format a type name with different options.
///
/// * [`TypeNameFormatLevel::Full`] returns `name` unchanged.
/// * [`TypeNameFormatLevel::Short`] keeps only the last `::` path segment of whatever precedes
///   the first `<`, dropping all generic arguments.
/// * [`TypeNameFormatLevel::Balanced`] keeps the last path segment of the type and, recursively,
///   of each generic argument, each argument being prefixed with `" | "`. Arguments are split on
///   `", "`, so nested argument lists end up flattened, e.g. `a::B<c::D, E<f::G>>` becomes
///   `B | D | E | G`.
pub fn type_name_format(name: &str, level: TypeNameFormatLevel) -> String {
    match level {
        TypeNameFormatLevel::Full => name.to_string(),
        TypeNameFormatLevel::Short => {
            // `split` always yields at least one item, so the fallback is never taken.
            let head = name.split('<').next().unwrap_or(name);
            last_path_segment(head).to_string()
        }
        TypeNameFormatLevel::Balanced => {
            // Everything after the first `<` is the (possibly nested) generic argument list.
            let (head, generics) = match name.split_once('<') {
                Some((head, generics)) => (head, Some(generics)),
                None => (name, None),
            };

            // Arguments handed down by the list formatter may still carry closing brackets.
            let head = last_path_segment(head).trim().replace('>', "");

            match generics {
                None => head,
                Some(generics) => {
                    let inside = type_name_list_format(generics, level);
                    format!("{head}{inside}")
                }
            }
        }
    }
}

/// Returns what follows the last `::` of `path`, or `path` itself when it has no `::`.
fn last_path_segment(path: &str) -> &str {
    path.split("::").last().unwrap_or(path)
}

/// Formats every `", "`-separated entry of `name` with [`type_name_format`], prefixing each one
/// with `" | "`.
fn type_name_list_format(name: &str, level: TypeNameFormatLevel) -> String {
    let mut acc = String::new();

    for entry in name.split(", ") {
        acc.push_str(" | ");
        acc.push_str(&type_name_format(entry, level));
    }

    acc
}
296
#[cfg(test)]
mod test {
    use super::*;

    /// The balanced level keeps the last path segment of the type and of each generic argument.
    #[test]
    fn test_format_name() {
        let formatted = type_name_format(
            "burn_cubecl::kernel::unary_numeric::unary_numeric::UnaryNumeric<f32, burn_cubecl::tensor::base::CubeTensor<_>::copy::Copy, cubecl_cuda::runtime::CudaRuntime>",
            TypeNameFormatLevel::Balanced,
        );
        let expected = "UnaryNumeric | f32 | CubeTensor | Copy | CudaRuntime";

        assert_eq!(formatted, expected);
    }
}