Skip to main content

cubecl_runtime/config/
base.rs

1use crate::config::memory::MemoryConfig;
2use crate::config::streaming::StreamingConfig;
3
4use super::{autotune::AutotuneConfig, compilation::CompilationConfig, profiling::ProfilingConfig};
5use alloc::format;
6use alloc::string::{String, ToString};
7use alloc::sync::Arc;
8use cubecl_common::config::RuntimeConfig;
9
10/// Static mutex holding the global configuration, initialized as `None`.
11static CUBE_GLOBAL_CONFIG: spin::Mutex<Option<Arc<CubeClRuntimeConfig>>> = spin::Mutex::new(None);
12
13/// Represents the global configuration for `CubeCL`, combining profiling, autotuning, and compilation settings.
14#[derive(Default, Clone, Debug, serde::Serialize, serde::Deserialize)]
15pub struct CubeClRuntimeConfig {
16    /// Configuration for profiling `CubeCL` operations.
17    #[serde(default)]
18    pub profiling: ProfilingConfig,
19
20    /// Configuration for autotuning performance parameters.
21    #[serde(default)]
22    pub autotune: AutotuneConfig,
23
24    /// Configuration for compilation settings.
25    #[serde(default)]
26    pub compilation: CompilationConfig,
27
28    /// Configuration for streaming settings.
29    #[serde(default)]
30    pub streaming: StreamingConfig,
31
32    /// Configuration for memory settings.
33    #[serde(default)]
34    pub memory: MemoryConfig,
35}
36
37impl RuntimeConfig for CubeClRuntimeConfig {
38    fn storage() -> &'static spin::Mutex<Option<Arc<Self>>> {
39        &CUBE_GLOBAL_CONFIG
40    }
41
42    fn file_names() -> &'static [&'static str] {
43        &["cubecl.toml", "CubeCL.toml"]
44    }
45
46    fn section_file_names() -> &'static [(&'static str, &'static str)] {
47        &[("burn.toml", "cubecl"), ("Burn.toml", "cubecl")]
48    }
49
50    #[cfg(std_io)]
51    fn override_from_env(mut self) -> Self {
52        use super::compilation::CompilationLogLevel;
53        use crate::config::{
54            autotune::{AutotuneLevel, AutotuneLogLevel},
55            profiling::ProfilingLogLevel,
56        };
57
58        if let Ok(val) = std::env::var("CUBECL_DEBUG_LOG") {
59            self.compilation.logger.level = CompilationLogLevel::Full;
60            self.profiling.logger.level = ProfilingLogLevel::Medium;
61            self.autotune.logger.level = AutotuneLogLevel::Full;
62
63            match val.as_str() {
64                "stdout" => {
65                    self.compilation.logger.stdout = true;
66                    self.profiling.logger.stdout = true;
67                    self.autotune.logger.stdout = true;
68                }
69                "stderr" => {
70                    self.compilation.logger.stderr = true;
71                    self.profiling.logger.stderr = true;
72                    self.autotune.logger.stderr = true;
73                }
74                "1" | "true" => {
75                    let file_path = "/tmp/cubecl.log";
76                    self.compilation.logger.file = Some(file_path.into());
77                    self.profiling.logger.file = Some(file_path.into());
78                    self.autotune.logger.file = Some(file_path.into());
79                }
80                "0" | "false" => {
81                    self.compilation.logger.level = CompilationLogLevel::Disabled;
82                    self.profiling.logger.level = ProfilingLogLevel::Disabled;
83                    self.autotune.logger.level = AutotuneLogLevel::Disabled;
84                }
85                file_path => {
86                    self.compilation.logger.file = Some(file_path.into());
87                    self.profiling.logger.file = Some(file_path.into());
88                    self.autotune.logger.file = Some(file_path.into());
89                }
90            }
91        };
92
93        if let Ok(val) = std::env::var("CUBECL_DEBUG_OPTION") {
94            match val.as_str() {
95                "debug" => {
96                    self.compilation.logger.level = CompilationLogLevel::Full;
97                    self.profiling.logger.level = ProfilingLogLevel::Medium;
98                    self.autotune.logger.level = AutotuneLogLevel::Full;
99                }
100                "debug-full" => {
101                    self.compilation.logger.level = CompilationLogLevel::Full;
102                    self.profiling.logger.level = ProfilingLogLevel::Full;
103                    self.autotune.logger.level = AutotuneLogLevel::Full;
104                }
105                "profile" => {
106                    self.profiling.logger.level = ProfilingLogLevel::Basic;
107                }
108                "profile-medium" => {
109                    self.profiling.logger.level = ProfilingLogLevel::Medium;
110                }
111                "profile-full" => {
112                    self.profiling.logger.level = ProfilingLogLevel::Full;
113                }
114                _ => {}
115            }
116        };
117
118        if let Ok(val) = std::env::var("CUBECL_AUTOTUNE_LEVEL") {
119            match val.as_str() {
120                "minimal" | "0" => {
121                    self.autotune.level = AutotuneLevel::Minimal;
122                }
123                "balanced" | "1" => {
124                    self.autotune.level = AutotuneLevel::Balanced;
125                }
126                "extensive" | "2" => {
127                    self.autotune.level = AutotuneLevel::Extensive;
128                }
129                "full" | "3" => {
130                    self.autotune.level = AutotuneLevel::Full;
131                }
132                _ => {}
133            }
134        }
135
136        self
137    }
138}
139
/// How to format cubecl type names.
#[derive(Clone, Copy, Debug)]
pub enum TypeNameFormatLevel {
    /// No formatting is applied; the full name is kept.
    Full,
    /// Most information is removed for a small formatted name.
    Short,
    /// Balanced info is kept.
    Balanced,
}

/// Format a type name with different options.
///
/// * [`TypeNameFormatLevel::Full`] returns `name` unchanged.
/// * [`TypeNameFormatLevel::Short`] takes the text before the first `<` and keeps only its last
///   `::`-separated segment.
/// * [`TypeNameFormatLevel::Balanced`] does the same for the outer type (also trimming it and
///   removing any `>`), then appends every `", "`-separated generic argument, formatted
///   recursively and prefixed with `" | "`.
pub fn type_name_format(name: &str, level: TypeNameFormatLevel) -> String {
    match level {
        TypeNameFormatLevel::Full => name.to_string(),
        TypeNameFormatLevel::Short => {
            let outer = name.split_once('<').map_or(name, |(outer, _)| outer);
            last_path_segment(outer).to_string()
        }
        TypeNameFormatLevel::Balanced => {
            // Everything after the first `<` is the generic argument list, closing `>` included.
            let (outer, generics) = match name.split_once('<') {
                Some((outer, generics)) => (outer, Some(generics)),
                None => (name, None),
            };

            // Closing brackets left over from an enclosing generic list are dropped here.
            let mut formatted = last_path_segment(outer).trim().replace('>', "");
            if let Some(generics) = generics {
                formatted.push_str(&type_name_list_format(generics, level));
            }
            formatted
        }
    }
}

/// Returns the final piece of `path` when it is split on `::`.
fn last_path_segment(path: &str) -> &str {
    // `str::split` always yields at least one item, so the fallback is never taken.
    path.split("::").last().unwrap_or(path)
}

/// Formats every `", "`-separated entry of `name` and prefixes each result with `" | "`.
fn type_name_list_format(name: &str, level: TypeNameFormatLevel) -> String {
    let mut acc = String::new();

    for entry in name.split(", ") {
        acc.push_str(" | ");
        acc.push_str(&type_name_format(entry, level));
    }

    acc
}
207
#[cfg(test)]
mod test {
    use super::*;

    /// Checks the `Balanced` level on a fully qualified name that contains nested generics.
    #[test_log::test]
    fn test_format_name() {
        let formatted = type_name_format(
            "burn_cubecl::kernel::unary_numeric::unary_numeric::UnaryNumeric<f32, burn_cubecl::tensor::base::CubeTensor<_>::copy::Copy, cubecl_cuda::runtime::CudaRuntime>",
            TypeNameFormatLevel::Balanced,
        );
        let expected = "UnaryNumeric | f32 | CubeTensor | Copy | CudaRuntime";

        assert_eq!(formatted, expected);
    }
}
219}