cubecl_runtime/config/
base.rs1use crate::config::memory::MemoryConfig;
2use crate::config::streaming::StreamingConfig;
3
4use super::{autotune::AutotuneConfig, compilation::CompilationConfig, profiling::ProfilingConfig};
5use alloc::format;
6use alloc::string::{String, ToString};
7use alloc::sync::Arc;
8
9static CUBE_GLOBAL_CONFIG: spin::Mutex<Option<Arc<GlobalConfig>>> = spin::Mutex::new(None);
11
12#[derive(Default, Clone, Debug, serde::Serialize, serde::Deserialize)]
14pub struct GlobalConfig {
15 #[serde(default)]
17 pub profiling: ProfilingConfig,
18
19 #[serde(default)]
21 pub autotune: AutotuneConfig,
22
23 #[serde(default)]
25 pub compilation: CompilationConfig,
26
27 #[serde(default)]
29 pub streaming: StreamingConfig,
30
31 #[serde(default)]
33 pub memory: MemoryConfig,
34}
35
36impl GlobalConfig {
37 pub fn get() -> Arc<Self> {
54 let mut state = CUBE_GLOBAL_CONFIG.lock();
55 if state.as_ref().is_none() {
56 cfg_if::cfg_if! {
57 if #[cfg(std_io)] {
58 let config = Self::from_current_dir();
59 let config = config.override_from_env();
60 } else {
61 let config = Self::default();
62 }
63 }
64
65 *state = Some(Arc::new(config));
66 }
67
68 state.as_ref().cloned().unwrap()
69 }
70
71 #[cfg(std_io)]
72 pub fn save_default<P: AsRef<std::path::Path>>(path: P) -> std::io::Result<()> {
74 use std::io::Write;
75
76 let config = Self::get();
77 let content =
78 toml::to_string_pretty(config.as_ref()).expect("Default config should be serializable");
79 let mut file = std::fs::File::create(path)?;
80 file.write_all(content.as_bytes())?;
81
82 Ok(())
83 }
84
85 pub fn set(config: Self) {
94 let mut state = CUBE_GLOBAL_CONFIG.lock();
95 if state.is_some() {
96 panic!("Cannot set the global configuration multiple times.");
97 }
98 *state = Some(Arc::new(config));
99 }
100
101 #[cfg(std_io)]
102 pub fn override_from_env(mut self) -> Self {
104 use super::compilation::CompilationLogLevel;
105 use crate::config::{
106 autotune::{AutotuneLevel, AutotuneLogLevel},
107 profiling::ProfilingLogLevel,
108 };
109
110 if let Ok(val) = std::env::var("CUBECL_DEBUG_LOG") {
111 self.compilation.logger.level = CompilationLogLevel::Full;
112 self.profiling.logger.level = ProfilingLogLevel::Medium;
113 self.autotune.logger.level = AutotuneLogLevel::Full;
114
115 match val.as_str() {
116 "stdout" => {
117 self.compilation.logger.stdout = true;
118 self.profiling.logger.stdout = true;
119 self.autotune.logger.stdout = true;
120 }
121 "stderr" => {
122 self.compilation.logger.stderr = true;
123 self.profiling.logger.stderr = true;
124 self.autotune.logger.stderr = true;
125 }
126 "1" | "true" => {
127 let file_path = "/tmp/cubecl.log";
128 self.compilation.logger.file = Some(file_path.into());
129 self.profiling.logger.file = Some(file_path.into());
130 self.autotune.logger.file = Some(file_path.into());
131 }
132 "0" | "false" => {
133 self.compilation.logger.level = CompilationLogLevel::Disabled;
134 self.profiling.logger.level = ProfilingLogLevel::Disabled;
135 self.autotune.logger.level = AutotuneLogLevel::Disabled;
136 }
137 file_path => {
138 self.compilation.logger.file = Some(file_path.into());
139 self.profiling.logger.file = Some(file_path.into());
140 self.autotune.logger.file = Some(file_path.into());
141 }
142 }
143 };
144
145 if let Ok(val) = std::env::var("CUBECL_DEBUG_OPTION") {
146 match val.as_str() {
147 "debug" => {
148 self.compilation.logger.level = CompilationLogLevel::Full;
149 self.profiling.logger.level = ProfilingLogLevel::Medium;
150 self.autotune.logger.level = AutotuneLogLevel::Full;
151 }
152 "debug-full" => {
153 self.compilation.logger.level = CompilationLogLevel::Full;
154 self.profiling.logger.level = ProfilingLogLevel::Full;
155 self.autotune.logger.level = AutotuneLogLevel::Full;
156 }
157 "profile" => {
158 self.profiling.logger.level = ProfilingLogLevel::Basic;
159 }
160 "profile-medium" => {
161 self.profiling.logger.level = ProfilingLogLevel::Medium;
162 }
163 "profile-full" => {
164 self.profiling.logger.level = ProfilingLogLevel::Full;
165 }
166 _ => {}
167 }
168 };
169
170 if let Ok(val) = std::env::var("CUBECL_AUTOTUNE_LEVEL") {
171 match val.as_str() {
172 "minimal" | "0" => {
173 self.autotune.level = AutotuneLevel::Minimal;
174 }
175 "balanced" | "1" => {
176 self.autotune.level = AutotuneLevel::Balanced;
177 }
178 "extensive" | "2" => {
179 self.autotune.level = AutotuneLevel::Extensive;
180 }
181 "full" | "3" => {
182 self.autotune.level = AutotuneLevel::Full;
183 }
184 _ => {}
185 }
186 }
187
188 self
189 }
190
191 #[cfg(std_io)]
196 fn from_current_dir() -> Self {
197 let mut dir = std::env::current_dir().unwrap();
198
199 loop {
200 if let Ok(content) = Self::from_file_path(dir.join("cubecl.toml")) {
201 return content;
202 }
203
204 if let Ok(content) = Self::from_file_path(dir.join("CubeCL.toml")) {
205 return content;
206 }
207
208 if !dir.pop() {
209 break;
210 }
211 }
212
213 Self::default()
214 }
215
216 #[cfg(std_io)]
218 fn from_file_path<P: AsRef<std::path::Path>>(path: P) -> std::io::Result<Self> {
219 let content = std::fs::read_to_string(path)?;
220 let config: Self = match toml::from_str(&content) {
221 Ok(val) => val,
222 Err(err) => panic!("The file provided doesn't have the right format => {err:?}"),
223 };
224
225 Ok(config)
226 }
227}
228
/// Verbosity used by [`type_name_format`] when rendering a Rust type name.
#[derive(Clone, Copy, Debug)]
pub enum TypeNameFormatLevel {
    /// Keep the name exactly as given.
    Full,
    /// Keep only the last path segment of the outer type and drop all generic arguments.
    Short,
    /// Shorten the outer type and every generic argument to their last path segment,
    /// separated by ` | `.
    Balanced,
}

/// Formats a fully-qualified Rust type name according to `level`.
///
/// - [`TypeNameFormatLevel::Full`] returns `name` unchanged.
/// - [`TypeNameFormatLevel::Short`] keeps the last `::` segment of the part before the
///   first `<`.
/// - [`TypeNameFormatLevel::Balanced`] shortens the outer type the same way, then appends
///   each `", "`-separated generic argument after ` | `. Arguments are formatted
///   recursively, so nested generics are flattened into the same list. For example,
///   `a::UnaryNumeric<f32, b::CubeTensor<_>::copy::Copy, c::CudaRuntime>` becomes
///   `UnaryNumeric | f32 | CubeTensor | Copy | CudaRuntime`.
pub fn type_name_format(name: &str, level: TypeNameFormatLevel) -> String {
    /// Last `::`-separated segment of `path` (the whole input when there is no `::`).
    fn last_segment(path: &str) -> &str {
        path.split("::").last().unwrap_or(path)
    }

    // Split at the first `<`: `outer` is the whole name when there are no generics.
    let (outer, generics) = match name.split_once('<') {
        Some((outer, generics)) => (outer, Some(generics)),
        None => (name, None),
    };

    match level {
        TypeNameFormatLevel::Full => name.to_string(),
        TypeNameFormatLevel::Short => last_segment(outer).to_string(),
        TypeNameFormatLevel::Balanced => {
            // Arguments split out of an enclosing list keep their closing brackets
            // (e.g. `CudaRuntime>`), so strip every `>`.
            let outer = last_segment(outer).trim().replace('>', "");
            match generics {
                None => outer,
                Some(generics) => {
                    let inside = type_name_list_format(generics, level);
                    format!("{outer}{inside}")
                }
            }
        }
    }
}

/// Formats each `", "`-separated entry of a generic argument list with
/// [`type_name_format`]. The results are concatenated, each one prefixed by `" | "`.
fn type_name_list_format(name: &str, level: TypeNameFormatLevel) -> String {
    let mut acc = String::new();
    for arg in name.split(", ") {
        acc.push_str(" | ");
        acc.push_str(&type_name_format(arg, level));
    }
    acc
}
296
#[cfg(test)]
mod test {
    use super::*;

    /// A representative `core::any::type_name` output with nested generics.
    const FULL_NAME: &str = "burn_cubecl::kernel::unary_numeric::unary_numeric::UnaryNumeric<f32, burn_cubecl::tensor::base::CubeTensor<_>::copy::Copy, cubecl_cuda::runtime::CudaRuntime>";

    #[test]
    fn test_format_name() {
        let name = type_name_format(FULL_NAME, TypeNameFormatLevel::Balanced);

        assert_eq!(name, "UnaryNumeric | f32 | CubeTensor | Copy | CudaRuntime");
    }

    #[test]
    fn test_format_name_short() {
        let name = type_name_format(FULL_NAME, TypeNameFormatLevel::Short);

        assert_eq!(name, "UnaryNumeric");
    }

    #[test]
    fn test_format_name_full() {
        let name = type_name_format(FULL_NAME, TypeNameFormatLevel::Full);

        assert_eq!(name, FULL_NAME);
    }

    #[test]
    fn test_format_name_without_generics() {
        assert_eq!(type_name_format("a::b::Foo", TypeNameFormatLevel::Balanced), "Foo");
        assert_eq!(type_name_format("a::b::Foo", TypeNameFormatLevel::Short), "Foo");
    }

    #[test]
    fn test_format_name_nested_generics() {
        let name = type_name_format(
            "a::Outer<b::Inner<c::X>, d::Y>",
            TypeNameFormatLevel::Balanced,
        );

        assert_eq!(name, "Outer | Inner | X | Y");
    }
}