1use serde::{Deserialize, Serialize};
27use std::collections::HashMap;
28use std::sync::Arc;
29use tokio::sync::RwLock;
30
31use crate::mcp::types::McpLogLevel;
32
/// A single log record attributed to one MCP server.
///
/// Derives `Serialize`/`Deserialize`, so the field names and their order are
/// part of the serialized representation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct McpLogEntry {
    /// Name of the server the entry is attributed to.
    pub server_name: String,
    /// Severity of the entry.
    pub level: McpLogLevel,
    /// Human-readable message text.
    pub message: String,
    /// Optional JSON payload attached to the entry.
    pub data: Option<serde_json::Value>,
    /// Optional name of the logger that produced the entry.
    pub logger: Option<String>,
}
50
51impl McpLogEntry {
52 pub fn new(
54 server_name: impl Into<String>,
55 level: McpLogLevel,
56 message: impl Into<String>,
57 ) -> Self {
58 Self {
59 server_name: server_name.into(),
60 level,
61 message: message.into(),
62 data: None,
63 logger: None,
64 }
65 }
66
67 pub fn with_data(mut self, data: serde_json::Value) -> Self {
69 self.data = Some(data);
70 self
71 }
72
73 pub fn with_logger(mut self, logger: impl Into<String>) -> Self {
75 self.logger = Some(logger.into());
76 self
77 }
78}
79
/// Shared callback that receives a reference to a log entry.
///
/// The `Send + Sync` bounds allow the same callback to be held and invoked
/// from multiple threads.
pub type LogCallback = Arc<dyn Fn(&McpLogEntry) + Send + Sync>;
82
/// Level-filtering logger for MCP server log entries.
///
/// Every field is held behind an `Arc<RwLock<..>>`.
pub struct McpLogger {
    /// Log levels configured for individual servers, keyed by server name.
    server_levels: Arc<RwLock<HashMap<String, McpLogLevel>>>,
    /// Level used for servers that have no entry in `server_levels`.
    default_level: Arc<RwLock<McpLogLevel>>,
    /// Registered log callbacks.
    callbacks: Arc<RwLock<Vec<LogCallback>>>,
    /// Flag that turns logging on or off as a whole.
    enabled: Arc<RwLock<bool>>,
}
101
102impl McpLogger {
103 pub fn new() -> Self {
105 Self {
106 server_levels: Arc::new(RwLock::new(HashMap::new())),
107 default_level: Arc::new(RwLock::new(McpLogLevel::Info)),
108 callbacks: Arc::new(RwLock::new(Vec::new())),
109 enabled: Arc::new(RwLock::new(true)),
110 }
111 }
112
113 pub fn with_default_level(level: McpLogLevel) -> Self {
115 Self {
116 server_levels: Arc::new(RwLock::new(HashMap::new())),
117 default_level: Arc::new(RwLock::new(level)),
118 callbacks: Arc::new(RwLock::new(Vec::new())),
119 enabled: Arc::new(RwLock::new(true)),
120 }
121 }
122
123 pub async fn set_server_log_level(&self, server_name: &str, level: McpLogLevel) {
127 let mut levels = self.server_levels.write().await;
128 levels.insert(server_name.to_string(), level);
129 }
130
131 pub async fn get_server_log_level(&self, server_name: &str) -> McpLogLevel {
133 let levels = self.server_levels.read().await;
134 levels
135 .get(server_name)
136 .copied()
137 .unwrap_or(*self.default_level.read().await)
138 }
139
140 pub async fn remove_server_log_level(&self, server_name: &str) {
142 let mut levels = self.server_levels.write().await;
143 levels.remove(server_name);
144 }
145
146 pub async fn set_default_level(&self, level: McpLogLevel) {
148 let mut default = self.default_level.write().await;
149 *default = level;
150 }
151
152 pub async fn get_default_level(&self) -> McpLogLevel {
154 *self.default_level.read().await
155 }
156
157 pub async fn set_enabled(&self, enabled: bool) {
159 let mut e = self.enabled.write().await;
160 *e = enabled;
161 }
162
163 pub async fn is_enabled(&self) -> bool {
165 *self.enabled.read().await
166 }
167
168 pub async fn on_log(&self, callback: LogCallback) {
172 let mut callbacks = self.callbacks.write().await;
173 callbacks.push(callback);
174 }
175
176 pub async fn log(&self, entry: McpLogEntry) {
183 if !*self.enabled.read().await {
185 return;
186 }
187
188 let server_level = self.get_server_log_level(&entry.server_name).await;
190 if !server_level.should_log(entry.level) {
191 return;
192 }
193
194 self.forward_to_tracing(&entry);
196
197 let callbacks = self.callbacks.read().await;
199 for callback in callbacks.iter() {
200 callback(&entry);
201 }
202 }
203
204 fn forward_to_tracing(&self, entry: &McpLogEntry) {
206 let server = &entry.server_name;
207 let message = &entry.message;
208 let logger = entry.logger.as_deref().unwrap_or("mcp");
209
210 match entry.level {
211 McpLogLevel::Debug => {
212 if let Some(ref data) = entry.data {
213 tracing::debug!(
214 target: "mcp",
215 server = %server,
216 logger = %logger,
217 data = %data,
218 "{}", message
219 );
220 } else {
221 tracing::debug!(
222 target: "mcp",
223 server = %server,
224 logger = %logger,
225 "{}", message
226 );
227 }
228 }
229 McpLogLevel::Info => {
230 if let Some(ref data) = entry.data {
231 tracing::info!(
232 target: "mcp",
233 server = %server,
234 logger = %logger,
235 data = %data,
236 "{}", message
237 );
238 } else {
239 tracing::info!(
240 target: "mcp",
241 server = %server,
242 logger = %logger,
243 "{}", message
244 );
245 }
246 }
247 McpLogLevel::Warn => {
248 if let Some(ref data) = entry.data {
249 tracing::warn!(
250 target: "mcp",
251 server = %server,
252 logger = %logger,
253 data = %data,
254 "{}", message
255 );
256 } else {
257 tracing::warn!(
258 target: "mcp",
259 server = %server,
260 logger = %logger,
261 "{}", message
262 );
263 }
264 }
265 McpLogLevel::Error => {
266 if let Some(ref data) = entry.data {
267 tracing::error!(
268 target: "mcp",
269 server = %server,
270 logger = %logger,
271 data = %data,
272 "{}", message
273 );
274 } else {
275 tracing::error!(
276 target: "mcp",
277 server = %server,
278 logger = %logger,
279 "{}", message
280 );
281 }
282 }
283 }
284 }
285
286 pub async fn process_notification(&self, server_name: &str, params: &serde_json::Value) {
293 let level = params
295 .get("level")
296 .and_then(|v| v.as_str())
297 .and_then(McpLogLevel::parse)
298 .unwrap_or(McpLogLevel::Info);
299
300 let message = params
301 .get("data")
302 .and_then(|v| v.as_str())
303 .unwrap_or("")
304 .to_string();
305
306 let logger = params
307 .get("logger")
308 .and_then(|v| v.as_str())
309 .map(|s| s.to_string());
310
311 let entry = McpLogEntry {
312 server_name: server_name.to_string(),
313 level,
314 message,
315 data: params.get("data").cloned(),
316 logger,
317 };
318
319 self.log(entry).await;
320 }
321}
322
323impl Default for McpLogger {
324 fn default() -> Self {
325 Self::new()
326 }
327}
328
329impl Clone for McpLogger {
330 fn clone(&self) -> Self {
331 Self {
332 server_levels: self.server_levels.clone(),
333 default_level: self.default_level.clone(),
334 callbacks: self.callbacks.clone(),
335 enabled: self.enabled.clone(),
336 }
337 }
338}
339
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Mutex;

    /// Registers a callback on `logger` that counts the entries delivered to it.
    async fn attach_counter(logger: &McpLogger) -> Arc<AtomicUsize> {
        let count = Arc::new(AtomicUsize::new(0));
        let sink = Arc::clone(&count);
        logger
            .on_log(Arc::new(move |_entry: &McpLogEntry| {
                sink.fetch_add(1, Ordering::SeqCst);
            }))
            .await;
        count
    }

    /// Registers a callback on `logger` that stores a clone of every entry
    /// delivered to it. A std mutex is used because callbacks are synchronous.
    async fn attach_recorder(logger: &McpLogger) -> Arc<Mutex<Vec<McpLogEntry>>> {
        let recorded = Arc::new(Mutex::new(Vec::new()));
        let sink = Arc::clone(&recorded);
        logger
            .on_log(Arc::new(move |entry: &McpLogEntry| {
                sink.lock()
                    .expect("recorder mutex poisoned")
                    .push(entry.clone());
            }))
            .await;
        recorded
    }

    #[tokio::test]
    async fn test_logger_new() {
        let logger = McpLogger::new();
        assert!(logger.is_enabled().await);
        assert_eq!(logger.get_default_level().await, McpLogLevel::Info);
    }

    #[tokio::test]
    async fn test_logger_with_default_level() {
        let logger = McpLogger::with_default_level(McpLogLevel::Debug);
        assert_eq!(logger.get_default_level().await, McpLogLevel::Debug);
    }

    #[tokio::test]
    async fn test_set_server_log_level() {
        let logger = McpLogger::new();

        // Without an override the default level applies.
        assert_eq!(
            logger.get_server_log_level("test-server").await,
            McpLogLevel::Info
        );

        logger
            .set_server_log_level("test-server", McpLogLevel::Debug)
            .await;
        assert_eq!(
            logger.get_server_log_level("test-server").await,
            McpLogLevel::Debug
        );

        // Other servers are unaffected by the override.
        assert_eq!(
            logger.get_server_log_level("other-server").await,
            McpLogLevel::Info
        );
    }

    #[tokio::test]
    async fn test_remove_server_log_level() {
        let logger = McpLogger::new();

        logger
            .set_server_log_level("test-server", McpLogLevel::Debug)
            .await;
        assert_eq!(
            logger.get_server_log_level("test-server").await,
            McpLogLevel::Debug
        );

        logger.remove_server_log_level("test-server").await;
        assert_eq!(
            logger.get_server_log_level("test-server").await,
            McpLogLevel::Info
        );
    }

    #[tokio::test]
    async fn test_set_enabled() {
        let logger = McpLogger::new();

        assert!(logger.is_enabled().await);

        logger.set_enabled(false).await;
        assert!(!logger.is_enabled().await);

        logger.set_enabled(true).await;
        assert!(logger.is_enabled().await);
    }

    #[tokio::test]
    async fn test_log_callback() {
        let logger = McpLogger::new();
        let call_count = attach_counter(&logger).await;

        let entry = McpLogEntry::new("test-server", McpLogLevel::Info, "Test message");
        logger.log(entry).await;

        assert_eq!(call_count.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_log_level_filtering() {
        let logger = McpLogger::new();
        let call_count = attach_counter(&logger).await;

        logger
            .set_server_log_level("test-server", McpLogLevel::Warn)
            .await;

        // Below the threshold: dropped.
        let debug_entry = McpLogEntry::new("test-server", McpLogLevel::Debug, "Debug message");
        logger.log(debug_entry).await;
        assert_eq!(call_count.load(Ordering::SeqCst), 0);

        let info_entry = McpLogEntry::new("test-server", McpLogLevel::Info, "Info message");
        logger.log(info_entry).await;
        assert_eq!(call_count.load(Ordering::SeqCst), 0);

        // At or above the threshold: delivered.
        let warn_entry = McpLogEntry::new("test-server", McpLogLevel::Warn, "Warn message");
        logger.log(warn_entry).await;
        assert_eq!(call_count.load(Ordering::SeqCst), 1);

        let error_entry = McpLogEntry::new("test-server", McpLogLevel::Error, "Error message");
        logger.log(error_entry).await;
        assert_eq!(call_count.load(Ordering::SeqCst), 2);
    }

    #[tokio::test]
    async fn test_log_disabled() {
        let logger = McpLogger::new();
        let call_count = attach_counter(&logger).await;

        logger.set_enabled(false).await;

        let entry = McpLogEntry::new("test-server", McpLogLevel::Info, "Test message");
        logger.log(entry).await;

        assert_eq!(call_count.load(Ordering::SeqCst), 0);
    }

    #[tokio::test]
    async fn test_process_notification() {
        let logger = McpLogger::new();
        let recorded = attach_recorder(&logger).await;

        let params = serde_json::json!({
            "level": "info",
            "data": "Test notification message",
            "logger": "test-logger"
        });

        logger.process_notification("test-server", &params).await;

        // The callback runs synchronously inside `log`, so the entry is
        // already recorded once `process_notification` returns.
        let entries = recorded.lock().expect("recorder mutex poisoned");
        assert_eq!(entries.len(), 1);

        let entry = &entries[0];
        assert_eq!(entry.server_name, "test-server");
        assert_eq!(entry.level, McpLogLevel::Info);
        assert_eq!(entry.message, "Test notification message");
        assert_eq!(entry.logger.as_deref(), Some("test-logger"));
        assert_eq!(
            entry.data,
            Some(serde_json::json!("Test notification message"))
        );
    }

    #[tokio::test]
    async fn test_process_notification_structured_data() {
        let logger = McpLogger::new();
        let recorded = attach_recorder(&logger).await;

        // No "level" and no "logger"; "data" is an object rather than a string.
        let params = serde_json::json!({ "data": { "code": 42 } });

        logger.process_notification("test-server", &params).await;

        let entries = recorded.lock().expect("recorder mutex poisoned");
        assert_eq!(entries.len(), 1);

        let entry = &entries[0];
        assert_eq!(entry.level, McpLogLevel::Info);
        assert_eq!(entry.message, "");
        assert!(entry.logger.is_none());
        assert_eq!(entry.data, Some(serde_json::json!({ "code": 42 })));
    }

    #[test]
    fn test_log_entry_new() {
        let entry = McpLogEntry::new("server", McpLogLevel::Info, "message");
        assert_eq!(entry.server_name, "server");
        assert_eq!(entry.level, McpLogLevel::Info);
        assert_eq!(entry.message, "message");
        assert!(entry.data.is_none());
        assert!(entry.logger.is_none());
    }

    #[test]
    fn test_log_entry_with_data() {
        let entry = McpLogEntry::new("server", McpLogLevel::Info, "message")
            .with_data(serde_json::json!({"key": "value"}));
        assert!(entry.data.is_some());
    }

    #[test]
    fn test_log_entry_with_logger() {
        let entry =
            McpLogEntry::new("server", McpLogLevel::Info, "message").with_logger("custom-logger");
        assert_eq!(entry.logger, Some("custom-logger".to_string()));
    }
}