Skip to main content

cubecl_runtime/config/
logger.rs

1use super::GlobalConfig;
2use crate::config::{
3    autotune::AutotuneLogLevel, compilation::CompilationLogLevel, memory::MemoryLogLevel,
4    profiling::ProfilingLogLevel, streaming::StreamingLogLevel,
5};
6use alloc::{string::ToString, sync::Arc, vec::Vec};
7use core::fmt::Display;
8use hashbrown::HashMap;
9
10#[cfg(std_io)]
11use std::{
12    eprintln,
13    fs::{File, OpenOptions},
14    io::{BufWriter, Write},
15    path::PathBuf,
16    println,
17};
18
19/// Configuration for logging in `CubeCL`, parameterized by a log level type.
20///
21/// Note that you can use multiple loggers at the same time.
22#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
23#[serde(bound = "")]
24pub struct LoggerConfig<L: LogLevel> {
25    /// Path to the log file, if file logging is enabled (requires `std` feature).
26    #[serde(default)]
27    #[cfg(std_io)]
28    pub file: Option<PathBuf>,
29
30    /// Whether to append to the log file (true) or overwrite it (false). Defaults to true.
31    ///
32    /// ## Notes
33    ///
34    /// This parameter might get ignored based on other loggers config.
35    #[serde(default = "append_default")]
36    pub append: bool,
37
38    /// Whether to log to standard output.
39    #[serde(default)]
40    pub stdout: bool,
41
42    /// Whether to log to standard error.
43    #[serde(default)]
44    pub stderr: bool,
45
46    /// Optional crate-level logging configuration (e.g., info, debug, trace).
47    #[serde(default)]
48    pub log: Option<LogCrateLevel>,
49
50    /// The log level for this logger, determining verbosity.
51    #[serde(default)]
52    pub level: L,
53}
54
55impl<L: LogLevel> Default for LoggerConfig<L> {
56    fn default() -> Self {
57        Self {
58            #[cfg(std_io)]
59            file: None,
60            append: true,
61            #[cfg(feature = "autotune-checks")]
62            stdout: true,
63            #[cfg(not(feature = "autotune-checks"))]
64            stdout: false,
65            stderr: false,
66            log: None,
67            level: L::default(),
68        }
69    }
70}
71
72/// Log levels using the `log` crate.
73///
74/// This enum defines verbosity levels for crate-level logging.
75#[derive(
76    Clone, Copy, Debug, Default, serde::Serialize, serde::Deserialize, Hash, PartialEq, Eq,
77)]
78pub enum LogCrateLevel {
79    /// Logs informational messages.
80    #[default]
81    #[serde(rename = "info")]
82    Info,
83
84    /// Logs debugging messages.
85    #[serde(rename = "debug")]
86    Debug,
87
88    /// Logs trace-level messages.
89    #[serde(rename = "trace")]
90    Trace,
91}
92
93impl LogLevel for u32 {}
94
/// Serde default for `LoggerConfig::append`: append to the log file unless told otherwise.
fn append_default() -> bool {
    true
}
98
99/// Trait for types that can be used as log levels in `LoggerConfig`.
100pub trait LogLevel:
101    serde::de::DeserializeOwned + serde::Serialize + Clone + Copy + core::fmt::Debug + Default
102{
103}
104
105/// Central logging utility for `CubeCL`, managing multiple log outputs.
106#[derive(Debug)]
107pub struct Logger {
108    /// Collection of logger instances (file, stdout, stderr, or crate-level).
109    loggers: Vec<LoggerKind>,
110
111    /// Indices of loggers used for compilation logging.
112    compilation_index: Vec<usize>,
113
114    /// Indices of loggers used for profiling logging.
115    profiling_index: Vec<usize>,
116
117    /// Indices of loggers used for autotuning logging.
118    autotune_index: Vec<usize>,
119
120    /// Indices of loggers used for streaming logging.
121    streaming_index: Vec<usize>,
122
123    /// Indices of loggers used for memory logging.
124    memory_index: Vec<usize>,
125
126    /// Global configuration for logging settings.
127    pub config: Arc<GlobalConfig>,
128}
129
130impl Default for Logger {
131    fn default() -> Self {
132        Self::new()
133    }
134}
135
136impl Logger {
137    /// Creates a new `Logger` instance based on the global configuration.
138    ///
139    /// Initializes loggers for compilation, profiling, and autotuning based on the settings in
140    /// `GlobalConfig`.
141    ///
142    /// Note that creating a logger is quite expensive.
143    pub fn new() -> Self {
144        let config = GlobalConfig::get();
145        let mut loggers = Vec::new();
146        let mut compilation_index = Vec::new();
147        let mut profiling_index = Vec::new();
148        let mut autotune_index = Vec::new();
149        let mut streaming_index = Vec::new();
150        let mut memory_index = Vec::new();
151
152        #[derive(Hash, PartialEq, Eq)]
153        enum LoggerId {
154            #[cfg(std_io)]
155            File(PathBuf),
156            #[cfg(feature = "std")]
157            Stdout,
158            #[cfg(feature = "std")]
159            Stderr,
160            LogCrate(LogCrateLevel),
161        }
162
163        let mut logger2index = HashMap::<LoggerId, usize>::new();
164
165        fn new_logger<S: Clone, ID: Fn(S) -> LoggerId, LG: Fn(S) -> LoggerKind>(
166            setting_index: &mut Vec<usize>,
167            loggers: &mut Vec<LoggerKind>,
168            logger2index: &mut HashMap<LoggerId, usize>,
169            state: S,
170            func_id: ID,
171            func_logger: LG,
172        ) {
173            let id = func_id(state.clone());
174
175            if let Some(index) = logger2index.get(&id) {
176                setting_index.push(*index);
177            } else {
178                let logger = func_logger(state);
179                let index = loggers.len();
180                logger2index.insert(id, index);
181                loggers.push(logger);
182                setting_index.push(index);
183            }
184        }
185
186        fn register_logger<L: LogLevel>(
187            #[allow(unused_variables)] kind: &LoggerConfig<L>, // not used in no-std
188            #[allow(unused_variables)] append: bool,           // not used in no-std
189            level: Option<LogCrateLevel>,
190            setting_index: &mut Vec<usize>,
191            loggers: &mut Vec<LoggerKind>,
192            logger2index: &mut HashMap<LoggerId, usize>,
193        ) {
194            #[cfg(std_io)]
195            if let Some(file) = &kind.file {
196                new_logger(
197                    setting_index,
198                    loggers,
199                    logger2index,
200                    (file, append),
201                    |(file, _append)| LoggerId::File(file.clone()),
202                    |(file, append)| LoggerKind::File(FileLogger::new(file, append)),
203                );
204            }
205
206            #[cfg(feature = "std")]
207            if kind.stdout {
208                new_logger(
209                    setting_index,
210                    loggers,
211                    logger2index,
212                    (),
213                    |_| LoggerId::Stdout,
214                    |_| LoggerKind::Stdout,
215                );
216            }
217
218            #[cfg(feature = "std")]
219            if kind.stderr {
220                new_logger(
221                    setting_index,
222                    loggers,
223                    logger2index,
224                    (),
225                    |_| LoggerId::Stderr,
226                    |_| LoggerKind::Stderr,
227                );
228            }
229
230            if let Some(level) = level {
231                new_logger(
232                    setting_index,
233                    loggers,
234                    logger2index,
235                    level,
236                    LoggerId::LogCrate,
237                    LoggerKind::Log,
238                );
239            }
240        }
241
242        if let CompilationLogLevel::Disabled = config.compilation.logger.level {
243        } else {
244            register_logger(
245                &config.compilation.logger,
246                config.compilation.logger.append,
247                config.compilation.logger.log,
248                &mut compilation_index,
249                &mut loggers,
250                &mut logger2index,
251            )
252        }
253
254        if let ProfilingLogLevel::Disabled = config.profiling.logger.level {
255        } else {
256            register_logger(
257                &config.profiling.logger,
258                config.profiling.logger.append,
259                config.profiling.logger.log,
260                &mut profiling_index,
261                &mut loggers,
262                &mut logger2index,
263            )
264        }
265
266        if let AutotuneLogLevel::Disabled = config.autotune.logger.level {
267        } else {
268            register_logger(
269                &config.autotune.logger,
270                config.autotune.logger.append,
271                config.autotune.logger.log,
272                &mut autotune_index,
273                &mut loggers,
274                &mut logger2index,
275            )
276        }
277
278        if let StreamingLogLevel::Disabled = config.streaming.logger.level {
279        } else {
280            register_logger(
281                &config.streaming.logger,
282                config.streaming.logger.append,
283                config.streaming.logger.log,
284                &mut streaming_index,
285                &mut loggers,
286                &mut logger2index,
287            )
288        }
289
290        if let MemoryLogLevel::Disabled = config.memory.logger.level {
291        } else {
292            register_logger(
293                &config.memory.logger,
294                config.memory.logger.append,
295                config.memory.logger.log,
296                &mut memory_index,
297                &mut loggers,
298                &mut logger2index,
299            )
300        }
301
302        Self {
303            loggers,
304            compilation_index,
305            profiling_index,
306            autotune_index,
307            streaming_index,
308            memory_index,
309            config,
310        }
311    }
312
313    /// Logs a message for streaming, directing it to all configured streaming loggers.
314    pub fn log_streaming<S: Display>(&mut self, msg: &S) {
315        let length = self.streaming_index.len();
316        if length > 1 {
317            let msg = msg.to_string();
318            for i in 0..length {
319                let index = self.streaming_index[i];
320                self.log(&msg, index)
321            }
322        } else if let Some(index) = self.streaming_index.first() {
323            self.log(&msg, *index)
324        }
325    }
326
327    /// Logs a message for memory, directing it to all configured streaming loggers.
328    pub fn log_memory<S: Display>(&mut self, msg: &S) {
329        let length = self.memory_index.len();
330        if length > 1 {
331            let msg = msg.to_string();
332            for i in 0..length {
333                let index = self.memory_index[i];
334                self.log(&msg, index)
335            }
336        } else if let Some(index) = self.memory_index.first() {
337            self.log(&msg, *index)
338        }
339    }
340
341    /// Logs a message for compilation, directing it to all configured compilation loggers.
342    pub fn log_compilation<S: Display>(&mut self, msg: &S) {
343        let length = self.compilation_index.len();
344        if length > 1 {
345            let msg = msg.to_string();
346            for i in 0..length {
347                let index = self.compilation_index[i];
348                self.log(&msg, index)
349            }
350        } else if let Some(index) = self.compilation_index.first() {
351            self.log(&msg, *index)
352        }
353    }
354
355    /// Logs a message for profiling, directing it to all configured profiling loggers.
356    pub fn log_profiling<S: Display>(&mut self, msg: &S) {
357        let length = self.profiling_index.len();
358        if length > 1 {
359            let msg = msg.to_string();
360            for i in 0..length {
361                let index = self.profiling_index[i];
362                self.log(&msg, index)
363            }
364        } else if let Some(index) = self.profiling_index.first() {
365            self.log(&msg, *index)
366        }
367    }
368
369    /// Logs a message for autotuning, directing it to all configured autotuning loggers.
370    pub fn log_autotune<S: Display>(&mut self, msg: &S) {
371        let length = self.autotune_index.len();
372        if length > 1 {
373            let msg = msg.to_string();
374            for i in 0..length {
375                let index = self.autotune_index[i];
376                self.log(&msg, index)
377            }
378        } else if let Some(index) = self.autotune_index.first() {
379            self.log(&msg, *index)
380        }
381    }
382
383    /// Returns the current streaming log level from the global configuration.
384    pub fn log_level_streaming(&self) -> StreamingLogLevel {
385        self.config.streaming.logger.level
386    }
387
388    /// Returns the current autotune log level from the global configuration.
389    pub fn log_level_autotune(&self) -> AutotuneLogLevel {
390        self.config.autotune.logger.level
391    }
392
393    /// Returns the current compilation log level from the global configuration.
394    pub fn log_level_compilation(&self) -> CompilationLogLevel {
395        self.config.compilation.logger.level
396    }
397
398    /// Returns the current profiling log level from the global configuration.
399    pub fn log_level_profiling(&self) -> ProfilingLogLevel {
400        self.config.profiling.logger.level
401    }
402
403    fn log<S: Display>(&mut self, msg: &S, index: usize) {
404        let logger = &mut self.loggers[index];
405        logger.log(msg);
406    }
407}
408
409/// Represents different types of loggers.
410#[derive(Debug)]
411enum LoggerKind {
412    /// Logs to a file.
413    #[cfg(std_io)]
414    File(FileLogger),
415
416    /// Logs to standard output.
417    #[cfg(feature = "std")]
418    Stdout,
419
420    /// Logs to standard error.
421    #[cfg(feature = "std")]
422    Stderr,
423
424    /// Logs using the `log` crate with a specified level.
425    Log(LogCrateLevel),
426}
427
428impl LoggerKind {
429    fn log<S: Display>(&mut self, msg: &S) {
430        match self {
431            #[cfg(std_io)]
432            LoggerKind::File(file_logger) => file_logger.log(msg),
433            #[cfg(feature = "std")]
434            LoggerKind::Stdout => println!("{msg}"),
435            #[cfg(feature = "std")]
436            LoggerKind::Stderr => eprintln!("{msg}"),
437            LoggerKind::Log(level) => match level {
438                LogCrateLevel::Info => log::info!("{msg}"),
439                LogCrateLevel::Trace => log::debug!("{msg}"),
440                LogCrateLevel::Debug => log::trace!("{msg}"),
441            },
442        }
443    }
444}
445
/// File-backed logger: messages go through a buffered writer wrapping the log file.
#[cfg(std_io)]
#[derive(Debug)]
struct FileLogger {
    /// Buffered handle on the opened log file.
    writer: BufWriter<File>,
}
452
#[cfg(std_io)]
impl FileLogger {
    /// Opens (creating it if needed) the log file at `path`.
    ///
    /// When `append` is `true`, new messages are added after the existing content. When it is
    /// `false`, the existing content is discarded so the file only holds the new messages.
    ///
    /// # Panics
    ///
    /// Panics if the file cannot be opened.
    fn new(path: &PathBuf, append: bool) -> Self {
        let file = OpenOptions::new()
            .write(true)
            .append(append)
            // Without truncation, overwriting a longer file would leave its old tail behind.
            // `truncate(true)` is only valid when `append` is off, hence the negation.
            .truncate(!append)
            .create(true)
            .open(path)
            .unwrap_or_else(|err| panic!("Unable to open log file {}: {err}", path.display()));

        Self {
            writer: BufWriter::new(file),
        }
    }

    /// Writes `msg` followed by a newline, then flushes so the message reaches the file
    /// immediately.
    ///
    /// # Panics
    ///
    /// Panics if writing or flushing fails.
    fn log<S: Display>(&mut self, msg: &S) {
        writeln!(self.writer, "{msg}").expect("Should be able to log debug information.");
        self.writer.flush().expect("Can complete write operation.");
    }
}
474}