cubecl_runtime/
debug.rs

1use core::fmt::Display;
2
3#[cfg(feature = "std")]
4use std::{
5    fs::{File, OpenOptions},
6    io::{BufWriter, Write},
7    path::PathBuf,
8};
9
10#[cfg(feature = "std")]
11use profile::*;
12
13#[cfg(feature = "std")]
mod profile {
    use core::fmt::Display;
    use std::collections::HashMap;

    const HEADER_NAME: &str = "Name";
    const HEADER_DURATION: &str = "Duration";
    const HEADER_NUM_COMPUTED: &str = "Num Computed";
    const HEADER_RATIO: &str = "Ratio";
    const TOTAL_NAME: &str = "Total";
    const TOTAL_RATIO: &str = "100 %";

    /// Aggregated profiling data, keyed by task name.
    #[derive(Debug, Default)]
    pub(crate) struct Profiled {
        durations: HashMap<String, ProfileItem>,
    }

    /// Accumulated statistics for a single profiled task.
    #[derive(Debug, Default, Clone)]
    pub(crate) struct ProfileItem {
        /// Sum of every duration registered for this task.
        total_duration: core::time::Duration,
        /// Number of times this task was registered.
        num_computed: usize,
    }

    impl Profiled {
        /// Record one execution of the task `name` that took `duration`.
        ///
        /// Only the first line of `name` is used as the key, so multi-line names are
        /// aggregated under their first line.
        pub fn update(&mut self, name: &str, duration: core::time::Duration) {
            // `split` always yields at least one item, so the fallback is never used.
            let name = name.split('\n').next().unwrap_or(name);

            // Look up by `&str` first so a key is only allocated for new tasks.
            if let Some(item) = self.durations.get_mut(name) {
                item.update(duration);
            } else {
                self.durations.insert(
                    name.to_string(),
                    ProfileItem {
                        total_duration: duration,
                        num_computed: 1,
                    },
                );
            }
        }
    }

    /// Width, in characters, of each column of the summary table.
    struct ColumnWidths {
        name: usize,
        duration: usize,
        num_computed: usize,
        ratio: usize,
    }

    impl ColumnWidths {
        /// Width of a row between its two outer `|` borders: the four columns, the three
        /// inner `" | "` separators and one padding space on each side (3 * 3 + 2 = 11).
        fn inner(&self) -> usize {
            self.name + self.duration + self.num_computed + self.ratio + 11
        }

        /// Write one table row, left-aligning every cell to its column width.
        fn write_row(
            &self,
            f: &mut core::fmt::Formatter<'_>,
            name: &str,
            duration: &str,
            num_computed: &str,
            ratio: &str,
        ) -> core::fmt::Result {
            writeln!(
                f,
                "| {:<width_name$} | {:<width_duration$} | {:<width_num_computed$} | {:<width_ratio$} |",
                name,
                duration,
                num_computed,
                ratio,
                width_name = self.name,
                width_duration = self.duration,
                width_num_computed = self.num_computed,
                width_ratio = self.ratio,
            )
        }

        /// Write a horizontal rule made of `fill`, exactly as wide as a row.
        fn write_rule(&self, f: &mut core::fmt::Formatter<'_>, fill: &str) -> core::fmt::Result {
            writeln!(f, "|{}|", fill.repeat(self.inner()))
        }
    }

    /// Character width of the widest of `cells` (0 when there are none).
    fn column_width<'a>(cells: impl Iterator<Item = &'a str>) -> usize {
        cells.map(|cell| cell.chars().count()).max().unwrap_or(0)
    }

    /// Integer percentage (rounded down) of `part` within `total`; `0` when `total` is zero,
    /// so tasks that took no measurable time cannot cause a division by zero.
    fn percent(part: core::time::Duration, total: core::time::Duration) -> u128 {
        match total.as_nanos() {
            0 => 0,
            total => part.as_nanos() * 100 / total,
        }
    }

    impl Display for Profiled {
        /// Render the summary as a table, slowest task first, followed by a total row.
        fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
            // (name, formatted count, formatted duration, raw duration) for every task.
            let mut rows: Vec<(&str, String, String, core::time::Duration)> = self
                .durations
                .iter()
                .map(|(name, item)| {
                    (
                        name.as_str(),
                        item.num_computed.to_string(),
                        format!("{:?}", item.total_duration),
                        item.total_duration,
                    )
                })
                .collect();
            // Slowest first; ties are ordered by name so the output does not depend on
            // `HashMap` iteration order.
            rows.sort_unstable_by(|a, b| b.3.cmp(&a.3).then_with(|| a.0.cmp(b.0)));

            let total_duration: core::time::Duration =
                self.durations.values().map(|item| item.total_duration).sum();
            let total_computed: usize = self.durations.values().map(|item| item.num_computed).sum();
            let total_duration_fmt = format!("{total_duration:?}");
            let total_computed_fmt = total_computed.to_string();

            // Every column must fit its header, the total row and all task rows.
            let widths = ColumnWidths {
                name: column_width(
                    [HEADER_NAME, TOTAL_NAME]
                        .iter()
                        .copied()
                        .chain(rows.iter().map(|row| row.0)),
                ),
                duration: column_width(
                    [HEADER_DURATION, total_duration_fmt.as_str()]
                        .iter()
                        .copied()
                        .chain(rows.iter().map(|row| row.2.as_str())),
                ),
                num_computed: column_width(
                    [HEADER_NUM_COMPUTED, total_computed_fmt.as_str()]
                        .iter()
                        .copied()
                        .chain(rows.iter().map(|row| row.1.as_str())),
                ),
                // Individual ratios never exceed "100 %".
                ratio: column_width([HEADER_RATIO, TOTAL_RATIO].iter().copied()),
            };

            widths.write_rule(f, "⎺")?;
            widths.write_row(f, HEADER_NAME, HEADER_DURATION, HEADER_NUM_COMPUTED, HEADER_RATIO)?;
            widths.write_rule(f, "⎼")?;

            for (name, num_computed, duration, raw_duration) in &rows {
                let ratio = format!("{} %", percent(*raw_duration, total_duration));
                widths.write_row(f, name, duration, num_computed, &ratio)?;
            }

            widths.write_rule(f, "⎼")?;
            widths.write_row(
                f,
                TOTAL_NAME,
                &total_duration_fmt,
                &total_computed_fmt,
                TOTAL_RATIO,
            )?;
            widths.write_rule(f, "⎯")
        }
    }

    impl ProfileItem {
        /// Add one more execution lasting `duration` to this task's totals.
        pub fn update(&mut self, duration: core::time::Duration) {
            self.total_duration += duration;
            self.num_computed += 1;
        }
    }
}
161
/// Controls how much information is displayed when profiling.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub enum ProfileLevel {
    /// Provide only the summary information about the kernels that ran.
    Basic,
    /// Provide the summary information about the kernels that ran, with their trace.
    Medium,
    /// Provide the most information about the kernels that ran.
    Full,
}
172
/// Selects which kinds of debugging output are enabled.
#[derive(Debug)]
pub enum DebugOptions {
    /// Only log debugging information about compilation; no profiling.
    Debug,
    /// Only profile executed kernels, at the given level.
    #[cfg(feature = "std")]
    Profile(ProfileLevel),
    /// Log debugging information and profile kernels at the given level.
    #[cfg(feature = "std")]
    All(ProfileLevel),
}
185
186/// Debugging logger.
187#[derive(Debug, Default)]
188pub struct DebugLogger {
189    kind: DebugLoggerKind,
190    #[cfg(feature = "std")]
191    profiled: Profiled,
192}
193
/// Destination of a [`DebugLogger`]'s output, together with the enabled [`DebugOptions`].
#[derive(Debug)]
pub enum DebugLoggerKind {
    /// Append debugging information to a file.
    #[cfg(feature = "std")]
    File(DebugFileLogger, DebugOptions),
    /// Print debugging information to standard output.
    #[cfg(feature = "std")]
    Stdout(DebugOptions),
    /// Debug logging is disabled.
    None,
}
206
207impl Default for DebugLoggerKind {
208    fn default() -> Self {
209        Self::new()
210    }
211}
212
213impl DebugLogger {
214    /// Returns the profile level, none if profiling is deactivated.
215    pub fn profile_level(&self) -> Option<ProfileLevel> {
216        self.kind.profile_level()
217    }
218
219    /// Register a profiled task.
220    #[cfg_attr(not(feature = "std"), expect(unused))]
221    pub fn register_profiled<Name>(&mut self, name: Name, duration: core::time::Duration)
222    where
223        Name: Display,
224    {
225        #[cfg(feature = "std")]
226        {
227            let name = name.to_string();
228            self.profiled.update(&name, duration);
229
230            match self.kind.profile_level().unwrap_or(ProfileLevel::Basic) {
231                ProfileLevel::Basic => {}
232                _ => self.kind.register_profiled(name, duration),
233            }
234        }
235    }
236    /// Returns whether the debug logger is activated.
237    pub fn is_activated(&self) -> bool {
238        !matches!(self.kind, DebugLoggerKind::None)
239    }
240    /// Log the argument to a file when the debug logger is activated.
241    pub fn debug<I>(&mut self, arg: I) -> I
242    where
243        I: Display,
244    {
245        self.kind.debug(arg)
246    }
247
248    /// Show the profiling summary if activated and reset its state.
249    pub fn profile_summary(&mut self) {
250        #[cfg(feature = "std")]
251        if self.profile_level().is_some() {
252            let mut profiled = Default::default();
253            core::mem::swap(&mut self.profiled, &mut profiled);
254
255            match &mut self.kind {
256                #[cfg(feature = "std")]
257                DebugLoggerKind::File(file, _) => {
258                    file.log(&format!("{}", profiled));
259                }
260                #[cfg(feature = "std")]
261                DebugLoggerKind::Stdout(_) => println!("{profiled}"),
262                _ => (),
263            }
264        }
265    }
266}
267
268impl DebugLoggerKind {
269    #[cfg(not(feature = "std"))]
270    /// Create a new debug logger.
271    pub fn new() -> Self {
272        Self::None
273    }
274
275    /// Create a new debug logger.
276    #[cfg(feature = "std")]
277    pub fn new() -> Self {
278        let flag = match std::env::var("CUBECL_DEBUG_LOG") {
279            Ok(val) => val,
280            Err(_) => return Self::None,
281        };
282        let level = match std::env::var("CUBECL_DEBUG_OPTION") {
283            Ok(val) => val,
284            Err(_) => "debug|profile".to_string(),
285        };
286
287        let mut debug = false;
288        let mut profile = None;
289        level.as_str().split("|").for_each(|flag| match flag {
290            "debug" => {
291                debug = true;
292            }
293            "profile" => {
294                profile = Some(ProfileLevel::Basic);
295            }
296            "profile-medium" => {
297                profile = Some(ProfileLevel::Medium);
298            }
299            "profile-full" => {
300                profile = Some(ProfileLevel::Full);
301            }
302            _ => {}
303        });
304
305        let option = if let Some(level) = profile {
306            if debug {
307                DebugOptions::All(level)
308            } else {
309                DebugOptions::Profile(level)
310            }
311        } else {
312            DebugOptions::Debug
313        };
314
315        if let Ok(activated) = str::parse::<u8>(&flag) {
316            if activated == 1 {
317                return Self::File(DebugFileLogger::new(None), option);
318            } else {
319                return Self::None;
320            }
321        };
322
323        if let Ok(activated) = str::parse::<bool>(&flag) {
324            if activated {
325                return Self::File(DebugFileLogger::new(None), option);
326            } else {
327                return Self::None;
328            }
329        };
330
331        if let "stdout" = flag.as_str() {
332            Self::Stdout(option)
333        } else {
334            Self::File(DebugFileLogger::new(Some(&flag)), option)
335        }
336    }
337
338    /// Returns the profile level, none if profiling is deactivated.
339    #[cfg(feature = "std")]
340    fn profile_level(&self) -> Option<ProfileLevel> {
341        let option = match self {
342            DebugLoggerKind::File(_, option) => option,
343            DebugLoggerKind::Stdout(option) => option,
344            DebugLoggerKind::None => {
345                return None;
346            }
347        };
348        match option {
349            DebugOptions::Debug => None,
350            DebugOptions::Profile(level) => Some(*level),
351            DebugOptions::All(level) => Some(*level),
352        }
353    }
354
355    /// Returns the profile level, none if profiling is deactivated.
356    #[cfg(not(feature = "std"))]
357    fn profile_level(&self) -> Option<ProfileLevel> {
358        None
359    }
360
361    #[cfg(feature = "std")]
362    fn register_profiled(&mut self, name: String, duration: core::time::Duration) {
363        match self {
364            #[cfg(feature = "std")]
365            DebugLoggerKind::File(file, _) => {
366                file.log(&format!("| {duration:<10?} | {name}"));
367            }
368            #[cfg(feature = "std")]
369            DebugLoggerKind::Stdout(_) => println!("| {duration:<10?} | {name}"),
370            _ => (),
371        }
372    }
373
374    fn debug<I>(&mut self, arg: I) -> I
375    where
376        I: Display,
377    {
378        match self {
379            #[cfg(feature = "std")]
380            DebugLoggerKind::File(file, option) => {
381                match option {
382                    DebugOptions::Debug | DebugOptions::All(_) => {
383                        file.log(&arg);
384                    }
385                    DebugOptions::Profile(_) => (),
386                };
387                arg
388            }
389            #[cfg(feature = "std")]
390            DebugLoggerKind::Stdout(option) => {
391                match option {
392                    DebugOptions::Debug | DebugOptions::All(_) => {
393                        println!("{arg}");
394                    }
395                    DebugOptions::Profile(_) => (),
396                };
397                arg
398            }
399            DebugLoggerKind::None => arg,
400        }
401    }
402}
403
/// Appends debugging information to a log file.
#[cfg(feature = "std")]
#[derive(Debug)]
pub struct DebugFileLogger {
    /// Buffered handle to the log file.
    writer: BufWriter<File>,
}
410
#[cfg(feature = "std")]
impl DebugFileLogger {
    /// Open the log file in append mode, creating it if needed.
    ///
    /// When `file_path` is `None`, the default is `/tmp/cubecl.log` on Unix and
    /// `cubecl.log` in the system temporary directory elsewhere.
    ///
    /// # Panics
    ///
    /// Panics if the file cannot be opened.
    fn new(file_path: Option<&str>) -> Self {
        let path = match file_path {
            Some(path) => PathBuf::from(path),
            // Keep the historical `/tmp` location on Unix; `/tmp` does not exist on every
            // platform (e.g. Windows), so elsewhere fall back to the OS temp directory.
            None if cfg!(unix) => PathBuf::from("/tmp/cubecl.log"),
            None => std::env::temp_dir().join("cubecl.log"),
        };

        let file = OpenOptions::new()
            .append(true)
            .create(true)
            .open(&path)
            .unwrap_or_else(|err| {
                panic!("Failed to open debug log file {}: {err}", path.display())
            });

        Self {
            writer: BufWriter::new(file),
        }
    }

    /// Append `msg` followed by a newline, flushing right away so each entry is visible
    /// without waiting for the buffer to fill.
    fn log<S: Display>(&mut self, msg: &S) {
        writeln!(self.writer, "{msg}").expect("Should be able to log debug information.");
        self.writer.flush().expect("Can complete write operation.");
    }
}