cubecl_runtime/logging/
server.rs1use core::fmt::Display;
2
3use crate::config::memory::MemoryLogLevel;
4use crate::config::streaming::StreamingLogLevel;
5use crate::config::{Logger, compilation::CompilationLogLevel, profiling::ProfilingLogLevel};
6use alloc::format;
7use alloc::string::String;
8use alloc::string::ToString;
9use async_channel::{Receiver, Sender};
10use cubecl_common::future::spawn_detached_fut;
11use cubecl_common::profile::ProfileDuration;
12
13use super::{ProfileLevel, Profiled};
14
15enum LogMessage {
16    Execution(String),
17    Compilation(String),
18    Streaming(String),
19    Memory(String),
20    Profile(String, ProfileDuration),
21    ProfileSummary,
22}
23
24#[derive(Debug)]
26pub struct ServerLogger {
27    profile_level: Option<ProfileLevel>,
28    log_compile_info: bool,
29    log_streaming: StreamingLogLevel,
30    log_channel: Option<Sender<LogMessage>>,
31    log_memory: MemoryLogLevel,
32}
33
34impl Default for ServerLogger {
35    fn default() -> Self {
36        let logger = Logger::new();
37
38        let disabled = matches!(
39            logger.config.compilation.logger.level,
40            CompilationLogLevel::Disabled
41        ) && matches!(
42            logger.config.profiling.logger.level,
43            ProfilingLogLevel::Disabled
44        ) && matches!(logger.config.memory.logger.level, MemoryLogLevel::Disabled)
45            && matches!(
46                logger.config.streaming.logger.level,
47                StreamingLogLevel::Disabled
48            );
49
50        if disabled {
51            return Self {
52                profile_level: None,
53                log_compile_info: false,
54                log_streaming: StreamingLogLevel::Disabled,
55                log_channel: None,
56                log_memory: MemoryLogLevel::Disabled,
57            };
58        }
59        let profile_level = match logger.config.profiling.logger.level {
60            ProfilingLogLevel::Disabled => None,
61            ProfilingLogLevel::Minimal => Some(ProfileLevel::ExecutionOnly),
62            ProfilingLogLevel::Basic => Some(ProfileLevel::Basic),
63            ProfilingLogLevel::Medium => Some(ProfileLevel::Medium),
64            ProfilingLogLevel::Full => Some(ProfileLevel::Full),
65        };
66
67        let log_compile_info = match logger.config.compilation.logger.level {
68            CompilationLogLevel::Disabled => false,
69            CompilationLogLevel::Basic => true,
70            CompilationLogLevel::Full => true,
71        };
72        let log_streaming = logger.config.streaming.logger.level;
73        let log_memory = logger.config.memory.logger.level;
74
75        let (send, rec) = async_channel::unbounded();
76
77        let async_logger = AsyncLogger {
79            message: rec,
80            logger,
81            profiled: Default::default(),
82        };
83        spawn_detached_fut(async_logger.process());
85
86        Self {
87            profile_level,
88            log_compile_info,
89            log_streaming,
90            log_memory,
91            log_channel: Some(send),
92        }
93    }
94}
95
96impl ServerLogger {
97    pub fn profile_level(&self) -> Option<ProfileLevel> {
99        self.profile_level
100    }
101
102    pub fn compilation_activated(&self) -> bool {
104        self.log_compile_info
105    }
106
107    pub fn log_compilation<I>(&self, arg: &I)
109    where
110        I: Display,
111    {
112        if let Some(channel) = &self.log_channel
113            && self.log_compile_info
114        {
115            let _ = channel.try_send(LogMessage::Compilation(arg.to_string()));
117        }
118    }
119
120    pub fn log_streaming<I: FnOnce() -> String, C: FnOnce(StreamingLogLevel) -> bool>(
122        &self,
123        cond: C,
124        format: I,
125    ) {
126        if let Some(channel) = &self.log_channel
127            && cond(self.log_streaming)
128        {
129            let _ = channel.try_send(LogMessage::Streaming(format()));
131        }
132    }
133
134    pub fn log_memory<I: FnOnce() -> String, C: FnOnce(MemoryLogLevel) -> bool>(
136        &self,
137        cond: C,
138        format: I,
139    ) {
140        if let Some(channel) = &self.log_channel
141            && cond(self.log_memory)
142        {
143            let _ = channel.try_send(LogMessage::Memory(format()));
145        }
146    }
147
148    pub fn register_execution(&self, name: impl Display) {
150        if let Some(channel) = &self.log_channel
151            && matches!(self.profile_level, Some(ProfileLevel::ExecutionOnly))
152        {
153            let _ = channel.try_send(LogMessage::Execution(name.to_string()));
155        }
156    }
157
158    pub fn register_profiled(&self, name: impl Display, duration: ProfileDuration) {
160        if let Some(channel) = &self.log_channel
161            && self.profile_level.is_some()
162        {
163            let _ = channel.try_send(LogMessage::Profile(name.to_string(), duration));
165        }
166    }
167
168    pub fn profile_summary(&self) {
170        if let Some(channel) = &self.log_channel
171            && self.profile_level.is_some()
172        {
173            let _ = channel.try_send(LogMessage::ProfileSummary);
175        }
176    }
177}
178
179struct AsyncLogger {
180    message: Receiver<LogMessage>,
181    logger: Logger,
182    profiled: Profiled,
183}
184
185impl AsyncLogger {
186    async fn process(mut self) {
187        while let Ok(msg) = self.message.recv().await {
188            match msg {
189                LogMessage::Compilation(msg) => {
190                    self.logger.log_compilation(&msg);
191                }
192                LogMessage::Streaming(msg) => {
193                    self.logger.log_streaming(&msg);
194                }
195                LogMessage::Memory(msg) => {
196                    self.logger.log_memory(&msg);
197                }
198                LogMessage::Profile(name, profile) => {
199                    let duration = profile.resolve().await.duration();
200                    self.profiled.update(&name, duration);
201                    self.logger
202                        .log_profiling(&format!("| {duration:<10?} | {name}"));
203                }
204                LogMessage::Execution(name) => {
205                    self.logger.log_profiling(&format!("Executing {name}"));
206                }
207                LogMessage::ProfileSummary => {
208                    if !self.profiled.is_empty() {
209                        self.logger.log_profiling(&self.profiled);
210                        self.profiled = Profiled::default();
211                    }
212                }
213            }
214        }
215    }
216}