cubecl_runtime/config/
base.rs1use crate::config::memory::MemoryConfig;
2use crate::config::streaming::StreamingConfig;
3
4use super::{autotune::AutotuneConfig, compilation::CompilationConfig, profiling::ProfilingConfig};
5use alloc::format;
6use alloc::string::{String, ToString};
7use alloc::sync::Arc;
8use cubecl_common::config::RuntimeConfig;
9
// NOTE(review): presumably populated lazily by the `RuntimeConfig` machinery in
// `cubecl_common` on first access — confirm there.
/// Process-wide slot holding the active [`CubeClRuntimeConfig`], handed out by
/// `RuntimeConfig::storage`. It starts out empty (`None`).
static CUBE_GLOBAL_CONFIG: spin::Mutex<Option<Arc<CubeClRuntimeConfig>>> = spin::Mutex::new(None);
12
/// Top-level CubeCL runtime configuration.
///
/// The file names it is associated with are declared in its `RuntimeConfig` impl
/// (`cubecl.toml` / `CubeCL.toml`, or the `cubecl` entry of `burn.toml` / `Burn.toml`).
/// Every field is `#[serde(default)]`, so a configuration source only needs to
/// specify the sections it wants to change; missing sections fall back to the
/// section type's `Default`.
#[derive(Default, Clone, Debug, serde::Serialize, serde::Deserialize)]
pub struct CubeClRuntimeConfig {
    /// Profiling settings. When the `std_io` cfg is enabled, the logger level and
    /// outputs can be overridden by `CUBECL_DEBUG_LOG` and `CUBECL_DEBUG_OPTION`
    /// (see `override_from_env`).
    #[serde(default)]
    pub profiling: ProfilingConfig,

    /// Autotune settings. When the `std_io` cfg is enabled, `level` can be
    /// overridden by `CUBECL_AUTOTUNE_LEVEL` and the logger by `CUBECL_DEBUG_LOG` /
    /// `CUBECL_DEBUG_OPTION` (see `override_from_env`).
    #[serde(default)]
    pub autotune: AutotuneConfig,

    /// Compilation settings. When the `std_io` cfg is enabled, the logger can be
    /// overridden by `CUBECL_DEBUG_LOG` / `CUBECL_DEBUG_OPTION` (see
    /// `override_from_env`).
    #[serde(default)]
    pub compilation: CompilationConfig,

    /// Streaming settings.
    #[serde(default)]
    pub streaming: StreamingConfig,

    /// Memory settings.
    #[serde(default)]
    pub memory: MemoryConfig,
}
36
37impl RuntimeConfig for CubeClRuntimeConfig {
38 fn storage() -> &'static spin::Mutex<Option<Arc<Self>>> {
39 &CUBE_GLOBAL_CONFIG
40 }
41
42 fn file_names() -> &'static [&'static str] {
43 &["cubecl.toml", "CubeCL.toml"]
44 }
45
46 fn section_file_names() -> &'static [(&'static str, &'static str)] {
47 &[("burn.toml", "cubecl"), ("Burn.toml", "cubecl")]
48 }
49
50 #[cfg(std_io)]
51 fn override_from_env(mut self) -> Self {
52 use super::compilation::CompilationLogLevel;
53 use crate::config::{
54 autotune::{AutotuneLevel, AutotuneLogLevel},
55 profiling::ProfilingLogLevel,
56 };
57
58 if let Ok(val) = std::env::var("CUBECL_DEBUG_LOG") {
59 self.compilation.logger.level = CompilationLogLevel::Full;
60 self.profiling.logger.level = ProfilingLogLevel::Medium;
61 self.autotune.logger.level = AutotuneLogLevel::Full;
62
63 match val.as_str() {
64 "stdout" => {
65 self.compilation.logger.stdout = true;
66 self.profiling.logger.stdout = true;
67 self.autotune.logger.stdout = true;
68 }
69 "stderr" => {
70 self.compilation.logger.stderr = true;
71 self.profiling.logger.stderr = true;
72 self.autotune.logger.stderr = true;
73 }
74 "1" | "true" => {
75 let file_path = "/tmp/cubecl.log";
76 self.compilation.logger.file = Some(file_path.into());
77 self.profiling.logger.file = Some(file_path.into());
78 self.autotune.logger.file = Some(file_path.into());
79 }
80 "0" | "false" => {
81 self.compilation.logger.level = CompilationLogLevel::Disabled;
82 self.profiling.logger.level = ProfilingLogLevel::Disabled;
83 self.autotune.logger.level = AutotuneLogLevel::Disabled;
84 }
85 file_path => {
86 self.compilation.logger.file = Some(file_path.into());
87 self.profiling.logger.file = Some(file_path.into());
88 self.autotune.logger.file = Some(file_path.into());
89 }
90 }
91 };
92
93 if let Ok(val) = std::env::var("CUBECL_DEBUG_OPTION") {
94 match val.as_str() {
95 "debug" => {
96 self.compilation.logger.level = CompilationLogLevel::Full;
97 self.profiling.logger.level = ProfilingLogLevel::Medium;
98 self.autotune.logger.level = AutotuneLogLevel::Full;
99 }
100 "debug-full" => {
101 self.compilation.logger.level = CompilationLogLevel::Full;
102 self.profiling.logger.level = ProfilingLogLevel::Full;
103 self.autotune.logger.level = AutotuneLogLevel::Full;
104 }
105 "profile" => {
106 self.profiling.logger.level = ProfilingLogLevel::Basic;
107 }
108 "profile-medium" => {
109 self.profiling.logger.level = ProfilingLogLevel::Medium;
110 }
111 "profile-full" => {
112 self.profiling.logger.level = ProfilingLogLevel::Full;
113 }
114 _ => {}
115 }
116 };
117
118 if let Ok(val) = std::env::var("CUBECL_AUTOTUNE_LEVEL") {
119 match val.as_str() {
120 "minimal" | "0" => {
121 self.autotune.level = AutotuneLevel::Minimal;
122 }
123 "balanced" | "1" => {
124 self.autotune.level = AutotuneLevel::Balanced;
125 }
126 "extensive" | "2" => {
127 self.autotune.level = AutotuneLevel::Extensive;
128 }
129 "full" | "3" => {
130 self.autotune.level = AutotuneLevel::Full;
131 }
132 _ => {}
133 }
134 }
135
136 self
137 }
138}
139
/// Controls how much of a Rust type name [`type_name_format`] keeps.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum TypeNameFormatLevel {
    /// Keep the name exactly as given.
    Full,
    /// Keep only the last `::` segment of the outer type; generic arguments are dropped.
    Short,
    /// Keep the last `::` segment of the outer type, followed by each generic
    /// argument (shortened recursively) as a flat ` | `-separated list.
    Balanced,
}

/// Formats a (typically fully-qualified) Rust type name according to `level`.
///
/// - [`TypeNameFormatLevel::Full`]: returns `name` unchanged.
/// - [`TypeNameFormatLevel::Short`]: `a::b::Foo<c::Bar>` becomes `Foo`.
/// - [`TypeNameFormatLevel::Balanced`]: `a::b::Foo<f32, c::Bar>` becomes
///   `Foo | f32 | Bar`. Nested generics are flattened into the same list rather
///   than bracketed, e.g. `Option<Vec<f32>>` becomes `Option | Vec | f32`.
///
/// Splitting is purely textual: the outer type ends at the first `<`, and
/// generic arguments are separated on `", "`.
pub fn type_name_format(name: &str, level: TypeNameFormatLevel) -> String {
    match level {
        TypeNameFormatLevel::Full => name.to_string(),
        TypeNameFormatLevel::Short => {
            let outer = name.split_once('<').map_or(name, |(outer, _)| outer);
            last_path_segment(outer).to_string()
        }
        TypeNameFormatLevel::Balanced => {
            let (outer, generics) = match name.split_once('<') {
                Some((outer, generics)) => (outer, Some(generics)),
                None => (name, None),
            };
            // When recursing into an argument list, trailing `>` from the enclosing
            // type end up attached to the last argument (e.g. `CudaRuntime>`).
            let outer = last_path_segment(outer).trim().replace('>', "");

            match generics {
                None => outer,
                Some(generics) => {
                    let inside = type_name_list_format(generics, level);
                    format!("{outer}{inside}")
                }
            }
        }
    }
}

/// Returns the text after the last `::` in `path`, or `path` itself if there is none.
fn last_path_segment(path: &str) -> &str {
    // `split` always yields at least one item, so the fallback is never used.
    // `split(..).last()` is kept rather than `rsplit` so overlapping `:::` runs
    // are cut the same way as before.
    path.split("::").last().unwrap_or(path)
}

/// Formats a `", "`-separated generic argument list, prefixing each argument
/// (formatted with [`type_name_format`] at `level`) with `" | "`.
fn type_name_list_format(name: &str, level: TypeNameFormatLevel) -> String {
    name.split(", ").fold(String::new(), |mut acc, argument| {
        acc.push_str(" | ");
        acc.push_str(&type_name_format(argument, level));
        acc
    })
}
207
#[cfg(test)]
mod test {
    use super::*;

    /// A real-world kernel type name with nested generics and a trailing path
    /// after a generic (`CubeTensor<_>::copy::Copy`).
    const FULL_NAME: &str = "burn_cubecl::kernel::unary_numeric::unary_numeric::UnaryNumeric<f32, burn_cubecl::tensor::base::CubeTensor<_>::copy::Copy, cubecl_cuda::runtime::CudaRuntime>";

    #[test_log::test]
    fn test_format_name() {
        let name = type_name_format(FULL_NAME, TypeNameFormatLevel::Balanced);

        assert_eq!(name, "UnaryNumeric | f32 | CubeTensor | Copy | CudaRuntime");
    }

    #[test_log::test]
    fn test_format_name_short() {
        let name = type_name_format(FULL_NAME, TypeNameFormatLevel::Short);

        assert_eq!(name, "UnaryNumeric");
    }

    #[test_log::test]
    fn test_format_name_full() {
        let name = type_name_format(FULL_NAME, TypeNameFormatLevel::Full);

        assert_eq!(name, FULL_NAME);
    }

    #[test_log::test]
    fn test_format_name_balanced_nested_generics() {
        let name = type_name_format(
            "core::option::Option<alloc::vec::Vec<f32>>",
            TypeNameFormatLevel::Balanced,
        );

        assert_eq!(name, "Option | Vec | f32");
    }

    #[test_log::test]
    fn test_format_name_without_generics() {
        assert_eq!(
            type_name_format("a::b::Foo", TypeNameFormatLevel::Balanced),
            "Foo"
        );
        assert_eq!(type_name_format("a::b::Foo", TypeNameFormatLevel::Short), "Foo");
        assert_eq!(type_name_format("", TypeNameFormatLevel::Balanced), "");
        assert_eq!(type_name_format("", TypeNameFormatLevel::Short), "");
    }
}
219}