1use rmcp::{
10 RoleServer,
11 model::{LoggingLevel, LoggingMessageNotificationParam},
12 service::Peer,
13};
14use serde_json::{Value, json};
15use std::sync::{
16 Arc,
17 atomic::{AtomicU8, Ordering},
18};
19use tracing::Level;
20
21pub struct LogLevelFilter(AtomicU8);
26
27impl LogLevelFilter {
28 pub fn new(level: LoggingLevel) -> Self {
30 Self(AtomicU8::new(level_to_u8(level)))
31 }
32
33 pub fn get(&self) -> LoggingLevel {
35 u8_to_level(self.0.load(Ordering::Relaxed))
36 }
37
38 pub fn set(&self, level: LoggingLevel) {
40 self.0.store(level_to_u8(level), Ordering::Relaxed);
41 }
42
43 pub fn should_log(&self, level: LoggingLevel) -> bool {
45 level_to_u8(level) >= self.0.load(Ordering::Relaxed)
46 }
47}
48
49impl Default for LogLevelFilter {
50 fn default() -> Self {
51 Self::new(LoggingLevel::Debug)
52 }
53}
54
55fn level_to_u8(level: LoggingLevel) -> u8 {
57 match level {
58 LoggingLevel::Debug => 0,
59 LoggingLevel::Info => 1,
60 LoggingLevel::Notice => 2,
61 LoggingLevel::Warning => 3,
62 LoggingLevel::Error => 4,
63 LoggingLevel::Critical => 5,
64 LoggingLevel::Alert => 6,
65 LoggingLevel::Emergency => 7,
66 }
67}
68
69fn u8_to_level(val: u8) -> LoggingLevel {
71 match val {
72 0 => LoggingLevel::Debug,
73 1 => LoggingLevel::Info,
74 2 => LoggingLevel::Notice,
75 3 => LoggingLevel::Warning,
76 4 => LoggingLevel::Error,
77 5 => LoggingLevel::Critical,
78 6 => LoggingLevel::Alert,
79 7 => LoggingLevel::Emergency,
80 _ => LoggingLevel::Debug,
81 }
82}
83
84pub fn logging_level_to_tracing(level: LoggingLevel) -> Level {
86 match level {
87 LoggingLevel::Debug => Level::DEBUG,
88 LoggingLevel::Info | LoggingLevel::Notice => Level::INFO,
89 LoggingLevel::Warning => Level::WARN,
90 LoggingLevel::Error
91 | LoggingLevel::Critical
92 | LoggingLevel::Alert
93 | LoggingLevel::Emergency => Level::ERROR,
94 }
95}
96
97#[derive(Clone)]
103pub struct Logger {
104 peer: Option<Peer<RoleServer>>,
106 level_filter: Arc<LogLevelFilter>,
108 name: Option<String>,
110}
111
112impl Logger {
113 pub fn new() -> Self {
115 Self {
116 peer: None,
117 level_filter: Arc::new(LogLevelFilter::default()),
118 name: None,
119 }
120 }
121
122 pub fn with_peer(mut self, peer: Peer<RoleServer>) -> Self {
124 self.peer = Some(peer);
125 self
126 }
127
128 pub fn with_level_filter(mut self, filter: Arc<LogLevelFilter>) -> Self {
130 self.level_filter = filter;
131 self
132 }
133
134 pub fn with_name(mut self, name: impl Into<String>) -> Self {
136 self.name = Some(name.into());
137 self
138 }
139
140 pub fn log(&self, level: LoggingLevel, message: &str, data: Option<Value>) {
142 if !self.level_filter.should_log(level) {
143 return;
144 }
145
146 let tracing_level = logging_level_to_tracing(level);
148 match tracing_level {
149 Level::ERROR => {
150 if let Some(ref name) = self.name {
151 tracing::error!(logger = %name, "{}", message);
152 } else {
153 tracing::error!("{}", message);
154 }
155 }
156 Level::WARN => {
157 if let Some(ref name) = self.name {
158 tracing::warn!(logger = %name, "{}", message);
159 } else {
160 tracing::warn!("{}", message);
161 }
162 }
163 Level::INFO => {
164 if let Some(ref name) = self.name {
165 tracing::info!(logger = %name, "{}", message);
166 } else {
167 tracing::info!("{}", message);
168 }
169 }
170 Level::DEBUG => {
171 if let Some(ref name) = self.name {
172 tracing::debug!(logger = %name, "{}", message);
173 } else {
174 tracing::debug!("{}", message);
175 }
176 }
177 Level::TRACE => {
178 if let Some(ref name) = self.name {
179 tracing::trace!(logger = %name, "{}", message);
180 } else {
181 tracing::trace!("{}", message);
182 }
183 }
184 }
185
186 if let Some(ref peer) = self.peer {
188 let param = LoggingMessageNotificationParam {
189 level,
190 logger: self.name.clone(),
191 data: data.unwrap_or_else(|| json!({ "message": message })),
192 };
193 let peer = peer.clone();
194 tokio::spawn(async move {
195 let _ = peer.notify_logging_message(param).await;
196 });
197 }
198 }
199
200 pub fn log_with_data(&self, level: LoggingLevel, message: &str, data: Value) {
202 self.log(level, message, Some(data));
203 }
204
205 pub fn debug(&self, msg: &str) {
209 self.log(LoggingLevel::Debug, msg, None);
210 }
211
212 pub fn info(&self, msg: &str) {
214 self.log(LoggingLevel::Info, msg, None);
215 }
216
217 pub fn notice(&self, msg: &str) {
219 self.log(LoggingLevel::Notice, msg, None);
220 }
221
222 pub fn warning(&self, msg: &str) {
224 self.log(LoggingLevel::Warning, msg, None);
225 }
226
227 pub fn error(&self, msg: &str) {
229 self.log(LoggingLevel::Error, msg, None);
230 }
231
232 pub fn critical(&self, msg: &str) {
234 self.log(LoggingLevel::Critical, msg, None);
235 }
236
237 pub fn alert(&self, msg: &str) {
239 self.log(LoggingLevel::Alert, msg, None);
240 }
241
242 pub fn emergency(&self, msg: &str) {
244 self.log(LoggingLevel::Emergency, msg, None);
245 }
246}
247
248impl Default for Logger {
249 fn default() -> Self {
250 Self::new()
251 }
252}
253
#[cfg(test)]
mod tests {
    use super::*;

    /// Every MCP level, ordered from least to most severe.
    const ALL_LEVELS: [LoggingLevel; 8] = [
        LoggingLevel::Debug,
        LoggingLevel::Info,
        LoggingLevel::Notice,
        LoggingLevel::Warning,
        LoggingLevel::Error,
        LoggingLevel::Critical,
        LoggingLevel::Alert,
        LoggingLevel::Emergency,
    ];

    #[test]
    fn test_level_filter() {
        let filter = LogLevelFilter::new(LoggingLevel::Warning);

        // Everything below the threshold is suppressed...
        for quiet in [LoggingLevel::Debug, LoggingLevel::Info, LoggingLevel::Notice] {
            assert!(!filter.should_log(quiet), "{:?} should be filtered", quiet);
        }

        // ...and the threshold itself plus everything above it passes.
        for loud in [
            LoggingLevel::Warning,
            LoggingLevel::Error,
            LoggingLevel::Critical,
            LoggingLevel::Alert,
            LoggingLevel::Emergency,
        ] {
            assert!(filter.should_log(loud), "{:?} should pass", loud);
        }
    }

    #[test]
    fn test_level_filter_update() {
        let filter = LogLevelFilter::new(LoggingLevel::Debug);
        assert!(filter.should_log(LoggingLevel::Debug));

        filter.set(LoggingLevel::Error);
        let debug_passes = filter.should_log(LoggingLevel::Debug);
        let warning_passes = filter.should_log(LoggingLevel::Warning);
        let error_passes = filter.should_log(LoggingLevel::Error);
        assert!(!debug_passes);
        assert!(!warning_passes);
        assert!(error_passes);
    }

    #[test]
    fn test_logging_level_to_tracing() {
        let expected = [
            (LoggingLevel::Debug, Level::DEBUG),
            (LoggingLevel::Info, Level::INFO),
            (LoggingLevel::Notice, Level::INFO),
            (LoggingLevel::Warning, Level::WARN),
            (LoggingLevel::Error, Level::ERROR),
            (LoggingLevel::Critical, Level::ERROR),
            (LoggingLevel::Alert, Level::ERROR),
            (LoggingLevel::Emergency, Level::ERROR),
        ];
        for (input, want) in expected {
            assert_eq!(logging_level_to_tracing(input), want);
        }
    }

    #[test]
    fn test_level_roundtrip() {
        for level in ALL_LEVELS {
            assert_eq!(LogLevelFilter::new(level).get(), level);
        }
    }
}