Skip to main content

cubecl_runtime/config/
logger.rs

1use super::GlobalConfig;
2use crate::config::{
3    autotune::AutotuneLogLevel, compilation::CompilationLogLevel, memory::MemoryLogLevel,
4    profiling::ProfilingLogLevel, streaming::StreamingLogLevel,
5};
6use alloc::{string::ToString, sync::Arc, vec::Vec};
7use core::fmt::Display;
8use hashbrown::HashMap;
9
10#[cfg(std_io)]
11use std::{
12    fs::{File, OpenOptions},
13    io::{BufWriter, Write},
14    path::PathBuf,
15};
16
17#[cfg(feature = "std")]
18use std::{eprintln, println};
19
20/// Configuration for logging in `CubeCL`, parameterized by a log level type.
21///
22/// Note that you can use multiple loggers at the same time.
23#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
24#[serde(bound = "")]
25pub struct LoggerConfig<L: LogLevel> {
26    /// Path to the log file, if file logging is enabled (requires `std` feature).
27    #[serde(default)]
28    #[cfg(std_io)]
29    pub file: Option<PathBuf>,
30
31    /// Whether to append to the log file (true) or overwrite it (false). Defaults to true.
32    ///
33    /// ## Notes
34    ///
35    /// This parameter might get ignored based on other loggers config.
36    #[serde(default = "append_default")]
37    pub append: bool,
38
39    /// Whether to log to standard output.
40    #[serde(default)]
41    pub stdout: bool,
42
43    /// Whether to log to standard error.
44    #[serde(default)]
45    pub stderr: bool,
46
47    /// Optional crate-level logging configuration (e.g., info, debug, trace).
48    #[serde(default)]
49    pub log: Option<LogCrateLevel>,
50
51    /// The log level for this logger, determining verbosity.
52    #[serde(default)]
53    pub level: L,
54}
55
56impl<L: LogLevel> Default for LoggerConfig<L> {
57    fn default() -> Self {
58        Self {
59            #[cfg(std_io)]
60            file: None,
61            append: true,
62            #[cfg(feature = "autotune-checks")]
63            stdout: true,
64            #[cfg(not(feature = "autotune-checks"))]
65            stdout: false,
66            stderr: false,
67            log: None,
68            level: L::default(),
69        }
70    }
71}
72
73/// Log levels using the `log` crate.
74///
75/// This enum defines verbosity levels for crate-level logging.
76#[derive(
77    Clone, Copy, Debug, Default, serde::Serialize, serde::Deserialize, Hash, PartialEq, Eq,
78)]
79pub enum LogCrateLevel {
80    /// Logs informational messages.
81    #[default]
82    #[serde(rename = "info")]
83    Info,
84
85    /// Logs debugging messages.
86    #[serde(rename = "debug")]
87    Debug,
88
89    /// Logs trace-level messages.
90    #[serde(rename = "trace")]
91    Trace,
92}
93
94impl LogLevel for u32 {}
95
/// Serde default for `LoggerConfig::append`: append to existing log files unless told otherwise.
const fn append_default() -> bool {
    true
}
99
100/// Trait for types that can be used as log levels in `LoggerConfig`.
101pub trait LogLevel:
102    serde::de::DeserializeOwned + serde::Serialize + Clone + Copy + core::fmt::Debug + Default
103{
104}
105
106/// Central logging utility for `CubeCL`, managing multiple log outputs.
107#[derive(Debug)]
108pub struct Logger {
109    /// Collection of logger instances (file, stdout, stderr, or crate-level).
110    loggers: Vec<LoggerKind>,
111
112    /// Indices of loggers used for compilation logging.
113    compilation_index: Vec<usize>,
114
115    /// Indices of loggers used for profiling logging.
116    profiling_index: Vec<usize>,
117
118    /// Indices of loggers used for autotuning logging.
119    autotune_index: Vec<usize>,
120
121    /// Indices of loggers used for streaming logging.
122    streaming_index: Vec<usize>,
123
124    /// Indices of loggers used for memory logging.
125    memory_index: Vec<usize>,
126
127    /// Global configuration for logging settings.
128    pub config: Arc<GlobalConfig>,
129}
130
131impl Default for Logger {
132    fn default() -> Self {
133        Self::new()
134    }
135}
136
137impl Logger {
138    /// Creates a new `Logger` instance based on the global configuration.
139    ///
140    /// Initializes loggers for compilation, profiling, and autotuning based on the settings in
141    /// `GlobalConfig`.
142    ///
143    /// Note that creating a logger is quite expensive.
144    pub fn new() -> Self {
145        let config = GlobalConfig::get();
146        let mut loggers = Vec::new();
147        let mut compilation_index = Vec::new();
148        let mut profiling_index = Vec::new();
149        let mut autotune_index = Vec::new();
150        let mut streaming_index = Vec::new();
151        let mut memory_index = Vec::new();
152
153        #[derive(Hash, PartialEq, Eq)]
154        enum LoggerId {
155            #[cfg(std_io)]
156            File(PathBuf),
157            #[cfg(feature = "std")]
158            Stdout,
159            #[cfg(feature = "std")]
160            Stderr,
161            LogCrate(LogCrateLevel),
162        }
163
164        let mut logger2index = HashMap::<LoggerId, usize>::new();
165
166        fn new_logger<S: Clone, ID: Fn(S) -> LoggerId, LG: Fn(S) -> LoggerKind>(
167            setting_index: &mut Vec<usize>,
168            loggers: &mut Vec<LoggerKind>,
169            logger2index: &mut HashMap<LoggerId, usize>,
170            state: S,
171            func_id: ID,
172            func_logger: LG,
173        ) {
174            let id = func_id(state.clone());
175
176            if let Some(index) = logger2index.get(&id) {
177                setting_index.push(*index);
178            } else {
179                let logger = func_logger(state);
180                let index = loggers.len();
181                logger2index.insert(id, index);
182                loggers.push(logger);
183                setting_index.push(index);
184            }
185        }
186
187        fn register_logger<L: LogLevel>(
188            #[allow(unused_variables)] kind: &LoggerConfig<L>, // not used in no-std
189            #[allow(unused_variables)] append: bool,           // not used in no-std
190            level: Option<LogCrateLevel>,
191            setting_index: &mut Vec<usize>,
192            loggers: &mut Vec<LoggerKind>,
193            logger2index: &mut HashMap<LoggerId, usize>,
194        ) {
195            #[cfg(std_io)]
196            if let Some(file) = &kind.file {
197                new_logger(
198                    setting_index,
199                    loggers,
200                    logger2index,
201                    (file, append),
202                    |(file, _append)| LoggerId::File(file.clone()),
203                    |(file, append)| LoggerKind::File(FileLogger::new(file, append)),
204                );
205            }
206
207            #[cfg(feature = "std")]
208            if kind.stdout {
209                new_logger(
210                    setting_index,
211                    loggers,
212                    logger2index,
213                    (),
214                    |_| LoggerId::Stdout,
215                    |_| LoggerKind::Stdout,
216                );
217            }
218
219            #[cfg(feature = "std")]
220            if kind.stderr {
221                new_logger(
222                    setting_index,
223                    loggers,
224                    logger2index,
225                    (),
226                    |_| LoggerId::Stderr,
227                    |_| LoggerKind::Stderr,
228                );
229            }
230
231            if let Some(level) = level {
232                new_logger(
233                    setting_index,
234                    loggers,
235                    logger2index,
236                    level,
237                    LoggerId::LogCrate,
238                    LoggerKind::Log,
239                );
240            }
241        }
242
243        if let CompilationLogLevel::Disabled = config.compilation.logger.level {
244        } else {
245            register_logger(
246                &config.compilation.logger,
247                config.compilation.logger.append,
248                config.compilation.logger.log,
249                &mut compilation_index,
250                &mut loggers,
251                &mut logger2index,
252            )
253        }
254
255        if let ProfilingLogLevel::Disabled = config.profiling.logger.level {
256        } else {
257            register_logger(
258                &config.profiling.logger,
259                config.profiling.logger.append,
260                config.profiling.logger.log,
261                &mut profiling_index,
262                &mut loggers,
263                &mut logger2index,
264            )
265        }
266
267        if let AutotuneLogLevel::Disabled = config.autotune.logger.level {
268        } else {
269            register_logger(
270                &config.autotune.logger,
271                config.autotune.logger.append,
272                config.autotune.logger.log,
273                &mut autotune_index,
274                &mut loggers,
275                &mut logger2index,
276            )
277        }
278
279        if let StreamingLogLevel::Disabled = config.streaming.logger.level {
280        } else {
281            register_logger(
282                &config.streaming.logger,
283                config.streaming.logger.append,
284                config.streaming.logger.log,
285                &mut streaming_index,
286                &mut loggers,
287                &mut logger2index,
288            )
289        }
290
291        if let MemoryLogLevel::Disabled = config.memory.logger.level {
292        } else {
293            register_logger(
294                &config.memory.logger,
295                config.memory.logger.append,
296                config.memory.logger.log,
297                &mut memory_index,
298                &mut loggers,
299                &mut logger2index,
300            )
301        }
302
303        Self {
304            loggers,
305            compilation_index,
306            profiling_index,
307            autotune_index,
308            streaming_index,
309            memory_index,
310            config,
311        }
312    }
313
314    /// Logs a message for streaming, directing it to all configured streaming loggers.
315    pub fn log_streaming<S: Display>(&mut self, msg: &S) {
316        let length = self.streaming_index.len();
317        if length > 1 {
318            let msg = msg.to_string();
319            for i in 0..length {
320                let index = self.streaming_index[i];
321                self.log(&msg, index)
322            }
323        } else if let Some(index) = self.streaming_index.first() {
324            self.log(&msg, *index)
325        }
326    }
327
328    /// Logs a message for memory, directing it to all configured streaming loggers.
329    pub fn log_memory<S: Display>(&mut self, msg: &S) {
330        let length = self.memory_index.len();
331        if length > 1 {
332            let msg = msg.to_string();
333            for i in 0..length {
334                let index = self.memory_index[i];
335                self.log(&msg, index)
336            }
337        } else if let Some(index) = self.memory_index.first() {
338            self.log(&msg, *index)
339        }
340    }
341
342    /// Logs a message for compilation, directing it to all configured compilation loggers.
343    pub fn log_compilation<S: Display>(&mut self, msg: &S) {
344        let length = self.compilation_index.len();
345        if length > 1 {
346            let msg = msg.to_string();
347            for i in 0..length {
348                let index = self.compilation_index[i];
349                self.log(&msg, index)
350            }
351        } else if let Some(index) = self.compilation_index.first() {
352            self.log(&msg, *index)
353        }
354    }
355
356    /// Logs a message for profiling, directing it to all configured profiling loggers.
357    pub fn log_profiling<S: Display>(&mut self, msg: &S) {
358        let length = self.profiling_index.len();
359        if length > 1 {
360            let msg = msg.to_string();
361            for i in 0..length {
362                let index = self.profiling_index[i];
363                self.log(&msg, index)
364            }
365        } else if let Some(index) = self.profiling_index.first() {
366            self.log(&msg, *index)
367        }
368    }
369
370    /// Logs a message for autotuning, directing it to all configured autotuning loggers.
371    pub fn log_autotune<S: Display>(&mut self, msg: &S) {
372        let length = self.autotune_index.len();
373        if length > 1 {
374            let msg = msg.to_string();
375            for i in 0..length {
376                let index = self.autotune_index[i];
377                self.log(&msg, index)
378            }
379        } else if let Some(index) = self.autotune_index.first() {
380            self.log(&msg, *index)
381        }
382    }
383
384    /// Returns the current streaming log level from the global configuration.
385    pub fn log_level_streaming(&self) -> StreamingLogLevel {
386        self.config.streaming.logger.level
387    }
388
389    /// Returns the current autotune log level from the global configuration.
390    pub fn log_level_autotune(&self) -> AutotuneLogLevel {
391        self.config.autotune.logger.level
392    }
393
394    /// Returns the current compilation log level from the global configuration.
395    pub fn log_level_compilation(&self) -> CompilationLogLevel {
396        self.config.compilation.logger.level
397    }
398
399    /// Returns the current profiling log level from the global configuration.
400    pub fn log_level_profiling(&self) -> ProfilingLogLevel {
401        self.config.profiling.logger.level
402    }
403
404    fn log<S: Display>(&mut self, msg: &S, index: usize) {
405        let logger = &mut self.loggers[index];
406        logger.log(msg);
407    }
408}
409
410/// Represents different types of loggers.
411#[derive(Debug)]
412enum LoggerKind {
413    /// Logs to a file.
414    #[cfg(std_io)]
415    File(FileLogger),
416
417    /// Logs to standard output.
418    #[cfg(feature = "std")]
419    Stdout,
420
421    /// Logs to standard error.
422    #[cfg(feature = "std")]
423    Stderr,
424
425    /// Logs using the `log` crate with a specified level.
426    Log(LogCrateLevel),
427}
428
429impl LoggerKind {
430    fn log<S: Display>(&mut self, msg: &S) {
431        match self {
432            #[cfg(std_io)]
433            LoggerKind::File(file_logger) => file_logger.log(msg),
434            #[cfg(feature = "std")]
435            LoggerKind::Stdout => println!("{msg}"),
436            #[cfg(feature = "std")]
437            LoggerKind::Stderr => eprintln!("{msg}"),
438            LoggerKind::Log(level) => match level {
439                LogCrateLevel::Info => log::info!("{msg}"),
440                LogCrateLevel::Trace => log::debug!("{msg}"),
441                LogCrateLevel::Debug => log::trace!("{msg}"),
442            },
443        }
444    }
445}
446
/// Output that writes messages to a file through a buffered writer.
#[cfg(std_io)]
#[derive(Debug)]
struct FileLogger {
    /// Buffered handle to the opened log file.
    writer: BufWriter<File>,
}
453
#[cfg(std_io)]
impl FileLogger {
    /// Opens the log file at `path`, creating it and any missing parent directories.
    ///
    /// When `append` is `true`, new messages go after the existing contents. When it is
    /// `false`, the file is truncated so that previous contents are replaced.
    ///
    /// # Panics
    ///
    /// Panics if the parent directory cannot be created or the file cannot be opened.
    fn new(path: &PathBuf, append: bool) -> Self {
        if let Some(parent) = path.parent() {
            std::fs::create_dir_all(parent).unwrap_or_else(|err| {
                panic!("Failed to create log directory {}: {err}", parent.display())
            });
        }

        // Without `truncate`, overwrite mode would keep the old bytes and write over them
        // from offset 0, leaving stale trailing content. `append` and `truncate` must never
        // both be set (the open would fail), hence `!append`.
        let file = OpenOptions::new()
            .write(true)
            .append(append)
            .truncate(!append)
            .create(true)
            .open(path)
            .unwrap_or_else(|err| panic!("Failed to open log file {}: {err}", path.display()));

        Self {
            writer: BufWriter::new(file),
        }
    }

    /// Writes `msg` followed by a newline, then flushes so the line reaches the file
    /// immediately.
    ///
    /// # Panics
    ///
    /// Panics if writing or flushing fails.
    fn log<S: Display>(&mut self, msg: &S) {
        writeln!(self.writer, "{msg}").expect("Should be able to log debug information.");
        self.writer.flush().expect("Can complete write operation.");
    }
}