cubecl_runtime/config/
base.rs1use super::{autotune::AutotuneConfig, compilation::CompilationConfig, profiling::ProfilingConfig};
2use alloc::format;
3use alloc::string::{String, ToString};
4use alloc::sync::Arc;
5
/// Process-wide configuration slot.
///
/// It stays `None` until it is filled, either explicitly by [`GlobalConfig::set`] or by the
/// first call to [`GlobalConfig::get`]. Once filled it is never replaced.
static CUBE_GLOBAL_CONFIG: spin::Mutex<Option<Arc<GlobalConfig>>> = spin::Mutex::new(None);
8
9#[derive(Default, Clone, Debug, serde::Serialize, serde::Deserialize)]
11pub struct GlobalConfig {
12 #[serde(default)]
14 pub profiling: ProfilingConfig,
15
16 #[serde(default)]
18 pub autotune: AutotuneConfig,
19
20 #[serde(default)]
22 pub compilation: CompilationConfig,
23}
24
25impl GlobalConfig {
26 pub fn get() -> Arc<Self> {
43 let mut state = CUBE_GLOBAL_CONFIG.lock();
44 if state.as_ref().is_none() {
45 cfg_if::cfg_if! {
46 if #[cfg(std_io)] {
47 let config = Self::from_current_dir();
48 let config = config.override_from_env();
49 } else {
50 let config = Self::default();
51 }
52 }
53
54 *state = Some(Arc::new(config));
55 }
56
57 state.as_ref().cloned().unwrap()
58 }
59
60 #[cfg(std_io)]
61 pub fn save_default<P: AsRef<std::path::Path>>(path: P) -> std::io::Result<()> {
63 use std::io::Write;
64
65 let config = Self::get();
66 let content =
67 toml::to_string_pretty(config.as_ref()).expect("Default config should be serializable");
68 let mut file = std::fs::File::create(path)?;
69 file.write_all(content.as_bytes())?;
70
71 Ok(())
72 }
73
74 pub fn set(config: Self) {
83 let mut state = CUBE_GLOBAL_CONFIG.lock();
84 if state.is_some() {
85 panic!("Cannot set the global configuration multiple times.");
86 }
87 *state = Some(Arc::new(config));
88 }
89
90 #[cfg(std_io)]
91 pub fn override_from_env(mut self) -> Self {
93 use super::compilation::CompilationLogLevel;
94 use crate::config::{
95 autotune::{AutotuneLevel, AutotuneLogLevel},
96 profiling::ProfilingLogLevel,
97 };
98
99 if let Ok(val) = std::env::var("CUBECL_DEBUG_LOG") {
100 self.compilation.logger.level = CompilationLogLevel::Full;
101 self.profiling.logger.level = ProfilingLogLevel::Medium;
102 self.autotune.logger.level = AutotuneLogLevel::Full;
103
104 match val.as_str() {
105 "stdout" => {
106 self.compilation.logger.stdout = true;
107 self.profiling.logger.stdout = true;
108 self.autotune.logger.stdout = true;
109 }
110 "stderr" => {
111 self.compilation.logger.stderr = true;
112 self.profiling.logger.stderr = true;
113 self.autotune.logger.stderr = true;
114 }
115 "1" | "true" => {
116 let file_path = "/tmp/cubecl.log";
117 self.compilation.logger.file = Some(file_path.into());
118 self.profiling.logger.file = Some(file_path.into());
119 self.autotune.logger.file = Some(file_path.into());
120 }
121 "0" | "false" => {
122 self.compilation.logger.level = CompilationLogLevel::Disabled;
123 self.profiling.logger.level = ProfilingLogLevel::Disabled;
124 self.autotune.logger.level = AutotuneLogLevel::Disabled;
125 }
126 file_path => {
127 self.compilation.logger.file = Some(file_path.into());
128 self.profiling.logger.file = Some(file_path.into());
129 self.autotune.logger.file = Some(file_path.into());
130 }
131 }
132 };
133
134 if let Ok(val) = std::env::var("CUBECL_DEBUG_OPTION") {
135 match val.as_str() {
136 "debug" => {
137 self.compilation.logger.level = CompilationLogLevel::Full;
138 self.profiling.logger.level = ProfilingLogLevel::Medium;
139 self.autotune.logger.level = AutotuneLogLevel::Full;
140 }
141 "debug-full" => {
142 self.compilation.logger.level = CompilationLogLevel::Full;
143 self.profiling.logger.level = ProfilingLogLevel::Full;
144 self.autotune.logger.level = AutotuneLogLevel::Full;
145 }
146 "profile" => {
147 self.profiling.logger.level = ProfilingLogLevel::Basic;
148 }
149 "profile-medium" => {
150 self.profiling.logger.level = ProfilingLogLevel::Medium;
151 }
152 "profile-full" => {
153 self.profiling.logger.level = ProfilingLogLevel::Full;
154 }
155 _ => {}
156 }
157 };
158
159 if let Ok(val) = std::env::var("CUBECL_AUTOTUNE_LEVEL") {
160 match val.as_str() {
161 "minimal" | "0" => {
162 self.autotune.level = AutotuneLevel::Minimal;
163 }
164 "balanced" | "1" => {
165 self.autotune.level = AutotuneLevel::Balanced;
166 }
167 "extensive" | "2" => {
168 self.autotune.level = AutotuneLevel::Extensive;
169 }
170 "full" | "3" => {
171 self.autotune.level = AutotuneLevel::Full;
172 }
173 _ => {}
174 }
175 }
176
177 self
178 }
179
180 #[cfg(std_io)]
185 fn from_current_dir() -> Self {
186 let mut dir = std::env::current_dir().unwrap();
187
188 loop {
189 if let Ok(content) = Self::from_file_path(dir.join("cubecl.toml")) {
190 return content;
191 }
192
193 if let Ok(content) = Self::from_file_path(dir.join("CubeCL.toml")) {
194 return content;
195 }
196
197 if !dir.pop() {
198 break;
199 }
200 }
201
202 Self::default()
203 }
204
205 #[cfg(std_io)]
207 fn from_file_path<P: AsRef<std::path::Path>>(path: P) -> std::io::Result<Self> {
208 let content = std::fs::read_to_string(path)?;
209 let config: Self = match toml::from_str(&content) {
210 Ok(val) => val,
211 Err(err) => panic!("The file provided doesn't have the right format => {err:?}"),
212 };
213
214 Ok(config)
215 }
216}
217
/// How much of a fully-qualified type name [`type_name_format`] keeps.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum TypeNameFormatLevel {
    /// Keep the name exactly as given, including module paths and generic arguments.
    Full,
    /// Keep only the last path segment of the outer type; generic arguments are dropped.
    Short,
    /// Keep the last path segment of the outer type and of every generic argument, flattened
    /// and separated by `" | "` (e.g. `a::Foo<b::Bar<c::Baz>>` becomes `Foo | Bar | Baz`).
    Balanced,
}

/// Formats a type name (e.g. as returned by `core::any::type_name`) at the given `level`.
///
/// The parsing is purely textual. Generic arguments start at the first `<` and are separated
/// by `", "`, so nested generics are flattened in [`TypeNameFormatLevel::Balanced`] mode.
pub fn type_name_format(name: &str, level: TypeNameFormatLevel) -> String {
    match level {
        TypeNameFormatLevel::Full => name.to_string(),
        TypeNameFormatLevel::Short => {
            // Everything before the first `<` is the outer type path.
            let outer = name.split_once('<').map_or(name, |(outer, _)| outer);
            last_path_segment(outer).to_string()
        }
        TypeNameFormatLevel::Balanced => {
            let (outer, generics) = match name.split_once('<') {
                Some((outer, generics)) => (outer, Some(generics)),
                None => (name, None),
            };
            // A list item may still carry `>` from the enclosing generic's closing bracket.
            let outer = last_path_segment(outer).trim().replace('>', "");
            match generics {
                None => outer,
                Some(generics) => format!("{outer}{}", type_name_list_format(generics, level)),
            }
        }
    }
}

/// Returns the text after the last `::` in `path`, or `path` itself when it has none.
fn last_path_segment(path: &str) -> &str {
    path.split("::").last().unwrap_or(path)
}

/// Formats each `", "`-separated generic argument in `name` and prefixes each one with
/// `" | "`.
fn type_name_list_format(name: &str, level: TypeNameFormatLevel) -> String {
    name.split(", ").fold(String::new(), |mut acc, arg| {
        acc.push_str(" | ");
        acc.push_str(&type_name_format(arg, level));
        acc
    })
}
285
#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn test_format_name() {
        let full_name = "burn_cubecl::kernel::unary_numeric::unary_numeric::UnaryNumeric<f32, burn_cubecl::tensor::base::CubeTensor<_>::copy::Copy, cubecl_cuda::runtime::CudaRuntime>";
        let name = type_name_format(full_name, TypeNameFormatLevel::Balanced);

        assert_eq!(name, "UnaryNumeric | f32 | CubeTensor | Copy | CudaRuntime");
    }

    #[test]
    fn test_format_name_full_is_identity() {
        let full_name = "alloc::vec::Vec<core::option::Option<u32>>";

        assert_eq!(type_name_format(full_name, TypeNameFormatLevel::Full), full_name);
    }

    #[test]
    fn test_format_name_short_drops_path_and_generics() {
        let full_name = "alloc::vec::Vec<core::option::Option<u32>>";

        assert_eq!(type_name_format(full_name, TypeNameFormatLevel::Short), "Vec");
        assert_eq!(type_name_format("f32", TypeNameFormatLevel::Short), "f32");
    }

    #[test]
    fn test_format_name_balanced_nested_and_plain() {
        let full_name = "alloc::vec::Vec<core::option::Option<u32>>";

        assert_eq!(
            type_name_format(full_name, TypeNameFormatLevel::Balanced),
            "Vec | Option | u32"
        );
        assert_eq!(type_name_format("u32", TypeNameFormatLevel::Balanced), "u32");
    }
}