Skip to main content

aster/mcp/
logging.rs

//! MCP Logging Module
//!
//! This module provides logging functionality for MCP servers, including:
//! - Log forwarding from server notifications to the application logger (Requirements 8.4)
//! - Configurable log levels per server (Requirements 8.5)
//! - Structured log entries with server context
//!
//! # Example
//!
//! ```rust,ignore
//! use aster::mcp::logging::{McpLogger, McpLogEntry};
//! use aster::mcp::types::McpLogLevel;
//!
//! let logger = McpLogger::new();
//! logger.set_server_log_level("my-server", McpLogLevel::Debug).await;
//!
//! // Log a message from a server
//! logger
//!     .log(McpLogEntry::new("my-server", McpLogLevel::Info, "Server started"))
//!     .await;
//! ```
25
26use serde::{Deserialize, Serialize};
27use std::collections::HashMap;
28use std::sync::Arc;
29use tokio::sync::RwLock;
30
31use crate::mcp::types::McpLogLevel;
32
33/// A log entry from an MCP server
34///
35/// This struct represents a log message received from an MCP server
36/// via the logging/message notification.
37#[derive(Debug, Clone, Serialize, Deserialize)]
38pub struct McpLogEntry {
39    /// Name of the server that generated the log
40    pub server_name: String,
41    /// Log level
42    pub level: McpLogLevel,
43    /// Log message
44    pub message: String,
45    /// Optional structured data
46    pub data: Option<serde_json::Value>,
47    /// Optional logger name from the server
48    pub logger: Option<String>,
49}
50
51impl McpLogEntry {
52    /// Create a new log entry
53    pub fn new(
54        server_name: impl Into<String>,
55        level: McpLogLevel,
56        message: impl Into<String>,
57    ) -> Self {
58        Self {
59            server_name: server_name.into(),
60            level,
61            message: message.into(),
62            data: None,
63            logger: None,
64        }
65    }
66
67    /// Add structured data to the log entry
68    pub fn with_data(mut self, data: serde_json::Value) -> Self {
69        self.data = Some(data);
70        self
71    }
72
73    /// Add logger name to the log entry
74    pub fn with_logger(mut self, logger: impl Into<String>) -> Self {
75        self.logger = Some(logger.into());
76        self
77    }
78}
79
80/// Callback type for log entry handlers
81pub type LogCallback = Arc<dyn Fn(&McpLogEntry) + Send + Sync>;
82
83/// MCP Logger for handling server log notifications
84///
85/// This logger manages log levels per server and forwards log messages
86/// to the application's logging system (tracing).
87///
88/// # Requirements Coverage
89/// - 8.4: Forward server log notifications to application logger
90/// - 8.5: Support configurable log levels per MCP server
91pub struct McpLogger {
92    /// Log levels per server
93    server_levels: Arc<RwLock<HashMap<String, McpLogLevel>>>,
94    /// Default log level for servers without specific configuration
95    default_level: Arc<RwLock<McpLogLevel>>,
96    /// Custom log callbacks
97    callbacks: Arc<RwLock<Vec<LogCallback>>>,
98    /// Whether logging is enabled
99    enabled: Arc<RwLock<bool>>,
100}
101
102impl McpLogger {
103    /// Create a new MCP logger with default settings
104    pub fn new() -> Self {
105        Self {
106            server_levels: Arc::new(RwLock::new(HashMap::new())),
107            default_level: Arc::new(RwLock::new(McpLogLevel::Info)),
108            callbacks: Arc::new(RwLock::new(Vec::new())),
109            enabled: Arc::new(RwLock::new(true)),
110        }
111    }
112
113    /// Create a new MCP logger with a specific default level
114    pub fn with_default_level(level: McpLogLevel) -> Self {
115        Self {
116            server_levels: Arc::new(RwLock::new(HashMap::new())),
117            default_level: Arc::new(RwLock::new(level)),
118            callbacks: Arc::new(RwLock::new(Vec::new())),
119            enabled: Arc::new(RwLock::new(true)),
120        }
121    }
122
123    /// Set the log level for a specific server
124    ///
125    /// # Requirements: 8.5
126    pub async fn set_server_log_level(&self, server_name: &str, level: McpLogLevel) {
127        let mut levels = self.server_levels.write().await;
128        levels.insert(server_name.to_string(), level);
129    }
130
131    /// Get the log level for a specific server
132    pub async fn get_server_log_level(&self, server_name: &str) -> McpLogLevel {
133        let levels = self.server_levels.read().await;
134        levels
135            .get(server_name)
136            .copied()
137            .unwrap_or(*self.default_level.read().await)
138    }
139
140    /// Remove the log level configuration for a server (falls back to default)
141    pub async fn remove_server_log_level(&self, server_name: &str) {
142        let mut levels = self.server_levels.write().await;
143        levels.remove(server_name);
144    }
145
146    /// Set the default log level for servers without specific configuration
147    pub async fn set_default_level(&self, level: McpLogLevel) {
148        let mut default = self.default_level.write().await;
149        *default = level;
150    }
151
152    /// Get the default log level
153    pub async fn get_default_level(&self) -> McpLogLevel {
154        *self.default_level.read().await
155    }
156
157    /// Enable or disable logging
158    pub async fn set_enabled(&self, enabled: bool) {
159        let mut e = self.enabled.write().await;
160        *e = enabled;
161    }
162
163    /// Check if logging is enabled
164    pub async fn is_enabled(&self) -> bool {
165        *self.enabled.read().await
166    }
167
168    /// Register a callback for log entries
169    ///
170    /// Returns a function that can be called to unregister the callback.
171    pub async fn on_log(&self, callback: LogCallback) {
172        let mut callbacks = self.callbacks.write().await;
173        callbacks.push(callback);
174    }
175
176    /// Log an entry from an MCP server
177    ///
178    /// This method checks the configured log level for the server and
179    /// forwards the message to the application logger if appropriate.
180    ///
181    /// # Requirements: 8.4
182    pub async fn log(&self, entry: McpLogEntry) {
183        // Check if logging is enabled
184        if !*self.enabled.read().await {
185            return;
186        }
187
188        // Check if this message should be logged based on server's configured level
189        let server_level = self.get_server_log_level(&entry.server_name).await;
190        if !server_level.should_log(entry.level) {
191            return;
192        }
193
194        // Forward to tracing
195        self.forward_to_tracing(&entry);
196
197        // Call registered callbacks
198        let callbacks = self.callbacks.read().await;
199        for callback in callbacks.iter() {
200            callback(&entry);
201        }
202    }
203
204    /// Forward a log entry to the tracing system
205    fn forward_to_tracing(&self, entry: &McpLogEntry) {
206        let server = &entry.server_name;
207        let message = &entry.message;
208        let logger = entry.logger.as_deref().unwrap_or("mcp");
209
210        match entry.level {
211            McpLogLevel::Debug => {
212                if let Some(ref data) = entry.data {
213                    tracing::debug!(
214                        target: "mcp",
215                        server = %server,
216                        logger = %logger,
217                        data = %data,
218                        "{}", message
219                    );
220                } else {
221                    tracing::debug!(
222                        target: "mcp",
223                        server = %server,
224                        logger = %logger,
225                        "{}", message
226                    );
227                }
228            }
229            McpLogLevel::Info => {
230                if let Some(ref data) = entry.data {
231                    tracing::info!(
232                        target: "mcp",
233                        server = %server,
234                        logger = %logger,
235                        data = %data,
236                        "{}", message
237                    );
238                } else {
239                    tracing::info!(
240                        target: "mcp",
241                        server = %server,
242                        logger = %logger,
243                        "{}", message
244                    );
245                }
246            }
247            McpLogLevel::Warn => {
248                if let Some(ref data) = entry.data {
249                    tracing::warn!(
250                        target: "mcp",
251                        server = %server,
252                        logger = %logger,
253                        data = %data,
254                        "{}", message
255                    );
256                } else {
257                    tracing::warn!(
258                        target: "mcp",
259                        server = %server,
260                        logger = %logger,
261                        "{}", message
262                    );
263                }
264            }
265            McpLogLevel::Error => {
266                if let Some(ref data) = entry.data {
267                    tracing::error!(
268                        target: "mcp",
269                        server = %server,
270                        logger = %logger,
271                        data = %data,
272                        "{}", message
273                    );
274                } else {
275                    tracing::error!(
276                        target: "mcp",
277                        server = %server,
278                        logger = %logger,
279                        "{}", message
280                    );
281                }
282            }
283        }
284    }
285
286    /// Process a logging notification from an MCP server
287    ///
288    /// This method parses the notification params and logs the message.
289    /// The notification format follows the MCP logging/message specification.
290    ///
291    /// # Requirements: 8.4
292    pub async fn process_notification(&self, server_name: &str, params: &serde_json::Value) {
293        // Parse the notification params
294        let level = params
295            .get("level")
296            .and_then(|v| v.as_str())
297            .and_then(McpLogLevel::parse)
298            .unwrap_or(McpLogLevel::Info);
299
300        let message = params
301            .get("data")
302            .and_then(|v| v.as_str())
303            .unwrap_or("")
304            .to_string();
305
306        let logger = params
307            .get("logger")
308            .and_then(|v| v.as_str())
309            .map(|s| s.to_string());
310
311        let entry = McpLogEntry {
312            server_name: server_name.to_string(),
313            level,
314            message,
315            data: params.get("data").cloned(),
316            logger,
317        };
318
319        self.log(entry).await;
320    }
321}
322
323impl Default for McpLogger {
324    fn default() -> Self {
325        Self::new()
326    }
327}
328
329impl Clone for McpLogger {
330    fn clone(&self) -> Self {
331        Self {
332            server_levels: self.server_levels.clone(),
333            default_level: self.default_level.clone(),
334            callbacks: self.callbacks.clone(),
335            enabled: self.enabled.clone(),
336        }
337    }
338}
339
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Registers a callback on `logger` that records a clone of every entry it
    /// receives, and returns the shared recording buffer.
    async fn capture_entries(logger: &McpLogger) -> Arc<std::sync::Mutex<Vec<McpLogEntry>>> {
        let received = Arc::new(std::sync::Mutex::new(Vec::<McpLogEntry>::new()));
        let received_clone = received.clone();
        logger
            .on_log(Arc::new(move |entry: &McpLogEntry| {
                received_clone.lock().unwrap().push(entry.clone());
            }))
            .await;
        received
    }

    #[tokio::test]
    async fn test_logger_new() {
        let logger = McpLogger::new();
        assert!(logger.is_enabled().await);
        assert_eq!(logger.get_default_level().await, McpLogLevel::Info);
    }

    #[tokio::test]
    async fn test_logger_with_default_level() {
        let logger = McpLogger::with_default_level(McpLogLevel::Debug);
        assert_eq!(logger.get_default_level().await, McpLogLevel::Debug);
    }

    #[tokio::test]
    async fn test_set_server_log_level() {
        let logger = McpLogger::new();

        // Default level should be Info
        assert_eq!(
            logger.get_server_log_level("test-server").await,
            McpLogLevel::Info
        );

        // Set specific level
        logger
            .set_server_log_level("test-server", McpLogLevel::Debug)
            .await;
        assert_eq!(
            logger.get_server_log_level("test-server").await,
            McpLogLevel::Debug
        );

        // Other servers should still use default
        assert_eq!(
            logger.get_server_log_level("other-server").await,
            McpLogLevel::Info
        );
    }

    #[tokio::test]
    async fn test_remove_server_log_level() {
        let logger = McpLogger::new();

        logger
            .set_server_log_level("test-server", McpLogLevel::Debug)
            .await;
        assert_eq!(
            logger.get_server_log_level("test-server").await,
            McpLogLevel::Debug
        );

        logger.remove_server_log_level("test-server").await;
        assert_eq!(
            logger.get_server_log_level("test-server").await,
            McpLogLevel::Info
        );
    }

    #[tokio::test]
    async fn test_set_default_level_applies_to_unconfigured_servers() {
        let logger = McpLogger::new();
        logger
            .set_server_log_level("pinned", McpLogLevel::Error)
            .await;
        logger.set_default_level(McpLogLevel::Warn).await;

        assert_eq!(logger.get_default_level().await, McpLogLevel::Warn);
        // Servers without an override follow the new default...
        assert_eq!(
            logger.get_server_log_level("unconfigured").await,
            McpLogLevel::Warn
        );
        // ...while explicit overrides are unaffected.
        assert_eq!(
            logger.get_server_log_level("pinned").await,
            McpLogLevel::Error
        );
    }

    #[tokio::test]
    async fn test_set_enabled() {
        let logger = McpLogger::new();

        assert!(logger.is_enabled().await);

        logger.set_enabled(false).await;
        assert!(!logger.is_enabled().await);

        logger.set_enabled(true).await;
        assert!(logger.is_enabled().await);
    }

    #[tokio::test]
    async fn test_clone_shares_state() {
        let logger = McpLogger::new();
        let handle = logger.clone();

        handle
            .set_server_log_level("test-server", McpLogLevel::Error)
            .await;
        handle.set_enabled(false).await;

        assert_eq!(
            logger.get_server_log_level("test-server").await,
            McpLogLevel::Error
        );
        assert!(!logger.is_enabled().await);
    }

    #[tokio::test]
    async fn test_log_callback() {
        let logger = McpLogger::new();
        let call_count = Arc::new(AtomicUsize::new(0));
        let call_count_clone = call_count.clone();

        logger
            .on_log(Arc::new(move |_entry| {
                call_count_clone.fetch_add(1, Ordering::SeqCst);
            }))
            .await;

        let entry = McpLogEntry::new("test-server", McpLogLevel::Info, "Test message");
        logger.log(entry).await;

        assert_eq!(call_count.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_log_level_filtering() {
        let logger = McpLogger::new();
        let call_count = Arc::new(AtomicUsize::new(0));
        let call_count_clone = call_count.clone();

        logger
            .on_log(Arc::new(move |_entry| {
                call_count_clone.fetch_add(1, Ordering::SeqCst);
            }))
            .await;

        // Set server level to Warn
        logger
            .set_server_log_level("test-server", McpLogLevel::Warn)
            .await;

        // Debug message should be filtered
        let debug_entry = McpLogEntry::new("test-server", McpLogLevel::Debug, "Debug message");
        logger.log(debug_entry).await;
        assert_eq!(call_count.load(Ordering::SeqCst), 0);

        // Info message should be filtered
        let info_entry = McpLogEntry::new("test-server", McpLogLevel::Info, "Info message");
        logger.log(info_entry).await;
        assert_eq!(call_count.load(Ordering::SeqCst), 0);

        // Warn message should pass
        let warn_entry = McpLogEntry::new("test-server", McpLogLevel::Warn, "Warn message");
        logger.log(warn_entry).await;
        assert_eq!(call_count.load(Ordering::SeqCst), 1);

        // Error message should pass
        let error_entry = McpLogEntry::new("test-server", McpLogLevel::Error, "Error message");
        logger.log(error_entry).await;
        assert_eq!(call_count.load(Ordering::SeqCst), 2);
    }

    #[tokio::test]
    async fn test_log_disabled() {
        let logger = McpLogger::new();
        let call_count = Arc::new(AtomicUsize::new(0));
        let call_count_clone = call_count.clone();

        logger
            .on_log(Arc::new(move |_entry| {
                call_count_clone.fetch_add(1, Ordering::SeqCst);
            }))
            .await;

        // Disable logging
        logger.set_enabled(false).await;

        let entry = McpLogEntry::new("test-server", McpLogLevel::Info, "Test message");
        logger.log(entry).await;

        // Callback should not be called
        assert_eq!(call_count.load(Ordering::SeqCst), 0);
    }

    #[tokio::test]
    async fn test_process_notification() {
        let logger = McpLogger::new();
        // Capture entries synchronously inside the callback; the previous version
        // wrote the message from a detached spawned task and never asserted it.
        let received = capture_entries(&logger).await;

        let params = serde_json::json!({
            "level": "info",
            "data": "Test notification message",
            "logger": "test-logger"
        });

        logger.process_notification("test-server", &params).await;

        let received = received.lock().unwrap();
        assert_eq!(received.len(), 1);
        let entry = &received[0];
        assert_eq!(entry.server_name, "test-server");
        assert_eq!(entry.level, McpLogLevel::Info);
        assert_eq!(entry.message, "Test notification message");
        assert_eq!(entry.logger.as_deref(), Some("test-logger"));
        assert_eq!(
            entry.data,
            Some(serde_json::json!("Test notification message"))
        );
    }

    #[tokio::test]
    async fn test_process_notification_structured_data() {
        let logger = McpLogger::new();
        let received = capture_entries(&logger).await;

        // Non-string `data` yields an empty message but is kept as structured data.
        let payload = serde_json::json!({ "event": "started", "port": 8080 });
        let params = serde_json::json!({ "level": "info", "data": payload.clone() });

        logger.process_notification("test-server", &params).await;

        let received = received.lock().unwrap();
        assert_eq!(received.len(), 1);
        assert_eq!(received[0].message, "");
        assert_eq!(received[0].data, Some(payload));
        assert!(received[0].logger.is_none());
    }

    #[test]
    fn test_log_entry_new() {
        let entry = McpLogEntry::new("server", McpLogLevel::Info, "message");
        assert_eq!(entry.server_name, "server");
        assert_eq!(entry.level, McpLogLevel::Info);
        assert_eq!(entry.message, "message");
        assert!(entry.data.is_none());
        assert!(entry.logger.is_none());
    }

    #[test]
    fn test_log_entry_with_data() {
        let entry = McpLogEntry::new("server", McpLogLevel::Info, "message")
            .with_data(serde_json::json!({"key": "value"}));
        assert!(entry.data.is_some());
    }

    #[test]
    fn test_log_entry_with_logger() {
        let entry =
            McpLogEntry::new("server", McpLogLevel::Info, "message").with_logger("custom-logger");
        assert_eq!(entry.logger, Some("custom-logger".to_string()));
    }
}