1use core::fmt::Display;
2
3#[cfg(feature = "std")]
4use std::{
5 fs::{File, OpenOptions},
6 io::{BufWriter, Write},
7 path::PathBuf,
8};
9
10#[cfg(feature = "std")]
11use profile::*;
12
13#[cfg(feature = "std")]
mod profile {
    //! Accumulates per-name timing measurements and renders them as a text table.

    use core::fmt::{self, Display, Formatter};
    use core::time::Duration;
    use std::collections::HashMap;

    /// Number of columns in the rendered summary table.
    const COLUMNS: usize = 4;

    /// Collection of timing measurements, aggregated by name.
    #[derive(Debug, Default)]
    pub(crate) struct Profiled {
        /// Aggregated measurements, keyed by the first line of the registered name.
        durations: HashMap<String, ProfileItem>,
    }

    /// Aggregated measurements registered under a single name.
    #[derive(Debug, Default, Clone)]
    pub(crate) struct ProfileItem {
        /// Sum of every duration registered for this entry.
        total_duration: Duration,
        /// Number of measurements registered for this entry.
        num_computed: usize,
    }

    impl Profiled {
        /// Adds `duration` to the entry identified by `name`.
        ///
        /// Only the first line of `name` is used as the key, so multi-line names are
        /// aggregated under their first line. A new entry is created on first use.
        pub fn update(&mut self, name: &str, duration: Duration) {
            // `split` always yields at least one piece, so the fallback is never taken.
            let name = name.split('\n').next().unwrap_or(name);

            // Look up first so that the key is only allocated when a new entry is needed.
            if let Some(item) = self.durations.get_mut(name) {
                item.update(duration);
            } else {
                self.durations.insert(
                    name.to_string(),
                    ProfileItem {
                        total_duration: duration,
                        num_computed: 1,
                    },
                );
            }
        }
    }

    impl Display for Profiled {
        /// Renders a table with one row per entry, sorted by decreasing total duration,
        /// followed by a row holding the totals.
        ///
        /// The ratio column is the integer percentage of the total duration spent in the
        /// entry. It is reported as `0 %` when the total duration is zero, instead of
        /// dividing by zero.
        fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
            const HEADERS: [&str; COLUMNS] = ["Name", "Duration", "Num Computed", "Ratio"];
            const TOTAL_LABEL: &str = "Total";
            const TOTAL_RATIO: &str = "100 %";

            let total_duration: Duration = self
                .durations
                .values()
                .map(|item| item.total_duration)
                .sum();
            let total_computed: usize = self.durations.values().map(|item| item.num_computed).sum();
            // Nanoseconds keep the ratio meaningful for sub-microsecond totals.
            let total_nanos = total_duration.as_nanos();

            let mut entries: Vec<(&String, &ProfileItem)> = self.durations.iter().collect();
            entries.sort_by(|(_, a), (_, b)| b.total_duration.cmp(&a.total_duration));

            let rows: Vec<[String; COLUMNS]> = entries
                .into_iter()
                .map(|(name, item)| {
                    // `checked_div` returns `None` for a zero total; report 0 % in that case.
                    let ratio = (100 * item.total_duration.as_nanos())
                        .checked_div(total_nanos)
                        .unwrap_or(0);

                    [
                        name.clone(),
                        format!("{:?}", item.total_duration),
                        item.num_computed.to_string(),
                        format!("{ratio} %"),
                    ]
                })
                .collect();

            let totals: [String; COLUMNS] = [
                TOTAL_LABEL.to_string(),
                format!("{total_duration:?}"),
                total_computed.to_string(),
                TOTAL_RATIO.to_string(),
            ];

            // Every column must fit its header, every entry and the totals row
            // (including the "Total" label itself).
            let mut widths = HEADERS.map(display_width);
            for row in rows.iter().chain(core::iter::once(&totals)) {
                for (width, cell) in widths.iter_mut().zip(row) {
                    *width = (*width).max(display_width(cell));
                }
            }

            // 11 = the " | " separators between columns plus the inner padding spaces.
            let line_length = widths.iter().sum::<usize>() + 11;

            write_rule(f, "⎺", line_length)?;
            write_row(f, &widths, &HEADERS)?;
            write_rule(f, "⎼", line_length)?;

            for row in &rows {
                write_row(f, &widths, row)?;
            }

            write_rule(f, "⎼", line_length)?;
            write_row(f, &widths, &totals)?;
            write_rule(f, "⎯", line_length)
        }
    }

    impl ProfileItem {
        /// Records one more measurement of `duration` for this entry.
        pub fn update(&mut self, duration: Duration) {
            self.total_duration += duration;
            self.num_computed += 1;
        }
    }

    /// Width of `text` in characters, the unit used by the `{:<width$}` padding.
    fn display_width(text: &str) -> usize {
        text.chars().count()
    }

    /// Writes a horizontal rule made of `length` repetitions of `glyph`.
    fn write_rule(f: &mut Formatter<'_>, glyph: &str, length: usize) -> fmt::Result {
        writeln!(f, "|{}| ", glyph.repeat(length))
    }

    /// Writes one table row, left-aligning every cell to the width of its column.
    fn write_row<C: Display>(
        f: &mut Formatter<'_>,
        widths: &[usize; COLUMNS],
        cells: &[C; COLUMNS],
    ) -> fmt::Result {
        writeln!(
            f,
            "| {:<width_name$} | {:<width_duration$} | {:<width_num_computed$} | {:<width_ratio$} |",
            cells[0],
            cells[1],
            cells[2],
            cells[3],
            width_name = widths[0],
            width_duration = widths[1],
            width_num_computed = widths[2],
            width_ratio = widths[3],
        )
    }
}
161
/// Level of detail requested for profiling output.
#[derive(Clone, Copy, Debug)]
pub enum ProfileLevel {
    /// Selected by the `profile` option. Measurements are only accumulated for the
    /// summary; they are not logged individually by this file.
    Basic,
    /// Selected by the `profile-medium` option. Each registered measurement is also
    /// logged individually.
    Medium,
    /// Selected by the `profile-full` option. Handled like [`ProfileLevel::Medium`]
    /// within this file; presumably users of the level add more detail (to confirm).
    Full,
}
172
/// Selects which kind of output the debug logger produces.
#[derive(Debug)]
pub enum DebugOptions {
    /// Only debug messages are emitted.
    Debug,
    /// Only profiling output is emitted, at the given level; debug messages are dropped.
    #[cfg(feature = "std")]
    Profile(ProfileLevel),
    /// Both debug messages and profiling output (at the given level) are emitted.
    #[cfg(feature = "std")]
    All(ProfileLevel),
}
185
186#[derive(Debug, Default)]
188pub struct DebugLogger {
189 kind: DebugLoggerKind,
190 #[cfg(feature = "std")]
191 profiled: Profiled,
192}
193
/// Output target of the debug logger.
#[derive(Debug)]
pub enum DebugLoggerKind {
    /// Messages are appended to a file, filtered by the given options.
    #[cfg(feature = "std")]
    File(DebugFileLogger, DebugOptions),
    /// Messages are printed to the standard output, filtered by the given options.
    #[cfg(feature = "std")]
    Stdout(DebugOptions),
    /// Logging is disabled.
    None,
}
206
207impl Default for DebugLoggerKind {
208 fn default() -> Self {
209 Self::new()
210 }
211}
212
213impl DebugLogger {
214 pub fn profile_level(&self) -> Option<ProfileLevel> {
216 self.kind.profile_level()
217 }
218
219 #[cfg_attr(not(feature = "std"), expect(unused))]
221 pub fn register_profiled<Name>(&mut self, name: Name, duration: core::time::Duration)
222 where
223 Name: Display,
224 {
225 #[cfg(feature = "std")]
226 {
227 let name = name.to_string();
228 self.profiled.update(&name, duration);
229
230 match self.kind.profile_level().unwrap_or(ProfileLevel::Basic) {
231 ProfileLevel::Basic => {}
232 _ => self.kind.register_profiled(name, duration),
233 }
234 }
235 }
236 pub fn is_activated(&self) -> bool {
238 !matches!(self.kind, DebugLoggerKind::None)
239 }
240 pub fn debug<I>(&mut self, arg: I) -> I
242 where
243 I: Display,
244 {
245 self.kind.debug(arg)
246 }
247
248 pub fn profile_summary(&mut self) {
250 #[cfg(feature = "std")]
251 if self.profile_level().is_some() {
252 let mut profiled = Default::default();
253 core::mem::swap(&mut self.profiled, &mut profiled);
254
255 match &mut self.kind {
256 #[cfg(feature = "std")]
257 DebugLoggerKind::File(file, _) => {
258 file.log(&format!("{}", profiled));
259 }
260 #[cfg(feature = "std")]
261 DebugLoggerKind::Stdout(_) => println!("{profiled}"),
262 _ => (),
263 }
264 }
265 }
266}
267
268impl DebugLoggerKind {
269 #[cfg(not(feature = "std"))]
270 pub fn new() -> Self {
272 Self::None
273 }
274
275 #[cfg(feature = "std")]
277 pub fn new() -> Self {
278 let flag = match std::env::var("CUBECL_DEBUG_LOG") {
279 Ok(val) => val,
280 Err(_) => return Self::None,
281 };
282 let level = match std::env::var("CUBECL_DEBUG_OPTION") {
283 Ok(val) => val,
284 Err(_) => "debug|profile".to_string(),
285 };
286
287 let mut debug = false;
288 let mut profile = None;
289 level.as_str().split("|").for_each(|flag| match flag {
290 "debug" => {
291 debug = true;
292 }
293 "profile" => {
294 profile = Some(ProfileLevel::Basic);
295 }
296 "profile-medium" => {
297 profile = Some(ProfileLevel::Medium);
298 }
299 "profile-full" => {
300 profile = Some(ProfileLevel::Full);
301 }
302 _ => {}
303 });
304
305 let option = if let Some(level) = profile {
306 if debug {
307 DebugOptions::All(level)
308 } else {
309 DebugOptions::Profile(level)
310 }
311 } else {
312 DebugOptions::Debug
313 };
314
315 if let Ok(activated) = str::parse::<u8>(&flag) {
316 if activated == 1 {
317 return Self::File(DebugFileLogger::new(None), option);
318 } else {
319 return Self::None;
320 }
321 };
322
323 if let Ok(activated) = str::parse::<bool>(&flag) {
324 if activated {
325 return Self::File(DebugFileLogger::new(None), option);
326 } else {
327 return Self::None;
328 }
329 };
330
331 if let "stdout" = flag.as_str() {
332 Self::Stdout(option)
333 } else {
334 Self::File(DebugFileLogger::new(Some(&flag)), option)
335 }
336 }
337
338 #[cfg(feature = "std")]
340 fn profile_level(&self) -> Option<ProfileLevel> {
341 let option = match self {
342 DebugLoggerKind::File(_, option) => option,
343 DebugLoggerKind::Stdout(option) => option,
344 DebugLoggerKind::None => {
345 return None;
346 }
347 };
348 match option {
349 DebugOptions::Debug => None,
350 DebugOptions::Profile(level) => Some(*level),
351 DebugOptions::All(level) => Some(*level),
352 }
353 }
354
355 #[cfg(not(feature = "std"))]
357 fn profile_level(&self) -> Option<ProfileLevel> {
358 None
359 }
360
361 #[cfg(feature = "std")]
362 fn register_profiled(&mut self, name: String, duration: core::time::Duration) {
363 match self {
364 #[cfg(feature = "std")]
365 DebugLoggerKind::File(file, _) => {
366 file.log(&format!("| {duration:<10?} | {name}"));
367 }
368 #[cfg(feature = "std")]
369 DebugLoggerKind::Stdout(_) => println!("| {duration:<10?} | {name}"),
370 _ => (),
371 }
372 }
373
374 fn debug<I>(&mut self, arg: I) -> I
375 where
376 I: Display,
377 {
378 match self {
379 #[cfg(feature = "std")]
380 DebugLoggerKind::File(file, option) => {
381 match option {
382 DebugOptions::Debug | DebugOptions::All(_) => {
383 file.log(&arg);
384 }
385 DebugOptions::Profile(_) => (),
386 };
387 arg
388 }
389 #[cfg(feature = "std")]
390 DebugLoggerKind::Stdout(option) => {
391 match option {
392 DebugOptions::Debug | DebugOptions::All(_) => {
393 println!("{arg}");
394 }
395 DebugOptions::Profile(_) => (),
396 };
397 arg
398 }
399 DebugLoggerKind::None => arg,
400 }
401 }
402}
403
/// Logger that writes messages to a file.
#[cfg(feature = "std")]
#[derive(Debug)]
pub struct DebugFileLogger {
    /// Buffered handle to the log file.
    writer: BufWriter<File>,
}
410
#[cfg(feature = "std")]
impl DebugFileLogger {
    /// Opens the log file at `file_path` in append mode, creating it if needed.
    ///
    /// When `file_path` is `None`, `/tmp/cubecl.log` is used.
    ///
    /// # Panics
    ///
    /// Panics when the file cannot be opened; the message names the path and the
    /// underlying error so that a bad path is easy to diagnose.
    fn new(file_path: Option<&str>) -> Self {
        let path = PathBuf::from(file_path.unwrap_or("/tmp/cubecl.log"));

        let file = OpenOptions::new()
            .append(true)
            .create(true)
            .open(&path)
            .unwrap_or_else(|err| {
                panic!(
                    "Should be able to open the debug log file {}: {err}",
                    path.display()
                )
            });

        Self {
            writer: BufWriter::new(file),
        }
    }

    /// Writes `msg` followed by a newline, then flushes so the line reaches the
    /// file immediately.
    ///
    /// # Panics
    ///
    /// Panics when writing or flushing fails.
    fn log<S: Display>(&mut self, msg: &S) {
        writeln!(self.writer, "{msg}").expect("Should be able to log debug information.");
        self.writer.flush().expect("Can complete write operation.");
    }
}
433}