use std::sync::atomic::{AtomicU8, Ordering};
use std::sync::{Arc, OnceLock};
use rust_mcp_sdk::McpServer;
use rust_mcp_sdk::schema::LoggingMessageNotificationParams;
/// Numeric MCP log severities.
///
/// The numbering matches the RFC 5424 syslog severity codes: a *lower* value is
/// *more* severe (`EMERGENCY` = 0) and a *higher* value is more verbose
/// (`DEBUG` = 7). The MCP logger forwards an entry to the client only when its
/// severity is numerically less than or equal to the configured minimum.
pub(crate) mod level {
    /// RFC 5424 severity 0 ("emergency").
    pub const EMERGENCY: u8 = 0;
    /// RFC 5424 severity 1 ("alert").
    pub const ALERT: u8 = 1;
    /// RFC 5424 severity 2 ("critical").
    pub const CRITICAL: u8 = 2;
    /// RFC 5424 severity 3 ("error").
    pub const ERROR: u8 = 3;

    /// RFC 5424 severity 4 ("warning").
    pub const WARNING: u8 = 4;
    /// RFC 5424 severity 5 ("notice").
    pub const NOTICE: u8 = 5;
    /// RFC 5424 severity 6 ("informational"); the logger's default minimum.
    pub const INFO: u8 = 6;
    /// RFC 5424 severity 7 ("debug"); the most verbose level.
    pub const DEBUG: u8 = 7;
}
fn logging_level_to_u8(level: rust_mcp_sdk::schema::LoggingLevel) -> u8 {
use rust_mcp_sdk::schema::LoggingLevel;
match level {
LoggingLevel::Emergency => level::EMERGENCY,
LoggingLevel::Alert => level::ALERT,
LoggingLevel::Critical => level::CRITICAL,
LoggingLevel::Error => level::ERROR,
LoggingLevel::Warning => level::WARNING,
LoggingLevel::Notice => level::NOTICE,
LoggingLevel::Info => level::INFO,
LoggingLevel::Debug => level::DEBUG,
}
}
/// Converts a numeric severity from the `level` module back into an MCP
/// `LoggingLevel`.
///
/// Every code in `0..=7` maps to its matching variant, and `level::INFO` is
/// handled explicitly rather than via the fallback. Values above `level::DEBUG`
/// are, by the module's ordering, *more verbose* than debug. They are therefore
/// reported as `Debug`, consistent with `McpLogger::log` sending unknown
/// severities to `tracing::debug!`.
fn u8_to_logging_level(v: u8) -> rust_mcp_sdk::schema::LoggingLevel {
    use rust_mcp_sdk::schema::LoggingLevel;
    match v {
        level::EMERGENCY => LoggingLevel::Emergency,
        level::ALERT => LoggingLevel::Alert,
        level::CRITICAL => LoggingLevel::Critical,
        level::ERROR => LoggingLevel::Error,
        level::WARNING => LoggingLevel::Warning,
        level::NOTICE => LoggingLevel::Notice,
        level::INFO => LoggingLevel::Info,
        // `level::DEBUG` and anything numerically beyond it (more verbose).
        _ => LoggingLevel::Debug,
    }
}
/// Logger that mirrors every entry to `tracing` and, once an MCP server has
/// been attached, forwards entries that pass the client-selected threshold
/// through `McpServer::notify_log_message`.
pub(crate) struct McpLogger {
    /// Client-selected threshold, stored as a code from the `level` module.
    /// Entries whose numeric severity is greater than this value are not
    /// forwarded to the client.
    min_level: AtomicU8,
    /// Handle to the MCP server used for notifications. It is set at most once
    /// (via `init`); until then, entries are only mirrored locally.
    runtime: OnceLock<Arc<dyn McpServer>>,
}
impl McpLogger {
    /// Creates a logger with no MCP server attached and a client-facing
    /// threshold of `level::INFO`.
    ///
    /// This is a `const fn` so it can initialise a `static`.
    pub(crate) const fn new() -> Self {
        Self {
            runtime: OnceLock::new(),
            min_level: AtomicU8::new(level::INFO),
        }
    }

    /// Attaches the MCP server used to forward log notifications.
    ///
    /// Only the first call takes effect. If a server is already attached, the
    /// new `runtime` is dropped and a warning is emitted locally. The
    /// misconfiguration is surfaced instead of being silently ignored.
    pub(crate) fn init(&self, runtime: Arc<dyn McpServer>) {
        if self.runtime.set(runtime).is_err() {
            tracing::warn!("MCP logger already initialized; ignoring re-initialization");
        }
    }

    /// Updates the client-selected threshold, then announces the change at
    /// `level::NOTICE` severity.
    ///
    /// The announcement itself only reaches the client if NOTICE passes the new
    /// threshold.
    pub(crate) fn set_level(&self, level: rust_mcp_sdk::schema::LoggingLevel) {
        // The threshold is an independent flag and publishes no other memory,
        // so `Relaxed` is sufficient.
        self.min_level
            .store(logging_level_to_u8(level), Ordering::Relaxed);
        self.log(
            level::NOTICE,
            "mcp",
            &format!("logging level set to {level:?}"),
            None,
        );
    }

    /// Records one log entry.
    ///
    /// - The entry is always mirrored to the local `tracing` subscriber.
    /// - It is additionally sent to the MCP client when all of these hold:
    ///   (a) `severity` is numerically `<=` the current threshold,
    ///   (b) a server has been attached with `init`, and
    ///   (c) the calling thread is inside a Tokio runtime context.
    /// - The notification is sent on a spawned task; delivery failures are only
    ///   reported at debug level.
    ///
    /// # Parameters
    /// - `severity`: a code from the `level` module; values above `level::DEBUG`
    ///   are treated as debug output.
    /// - `logger`: name reported as the notification's `logger` field.
    /// - `message`: human-readable text. It is used as the notification payload
    ///   when `data` is `None`.
    /// - `data`: optional structured payload that replaces `message` in the
    ///   notification.
    pub(crate) fn log(
        &self,
        severity: u8,
        logger: &str,
        message: &str,
        data: Option<serde_json::Value>,
    ) {
        Self::trace_locally(severity, logger, message);

        // Lower value = more severe; anything more verbose than the threshold
        // stays local.
        if severity > self.min_level.load(Ordering::Relaxed) {
            return;
        }
        let Some(rt) = self.runtime.get() else {
            return;
        };
        // `tokio::spawn` panics when called outside a Tokio runtime context.
        // Logging must never take the process down, so skip the notification.
        let Ok(handle) = tokio::runtime::Handle::try_current() else {
            tracing::debug!("MCP log notification skipped: no Tokio runtime on this thread");
            return;
        };

        let params = LoggingMessageNotificationParams {
            level: u8_to_logging_level(severity),
            logger: Some(logger.to_string()),
            data: data.unwrap_or_else(|| serde_json::Value::String(message.to_string())),
            meta: None,
        };
        let rt = Arc::clone(rt);
        handle.spawn(async move {
            if let Err(e) = rt.notify_log_message(params).await {
                tracing::debug!("MCP log notification failed: {e}");
            }
        });
    }

    /// Mirrors an entry to the local `tracing` subscriber.
    ///
    /// It collapses the eight MCP severities onto `tracing`'s levels: ERROR and
    /// above become `error`, WARNING becomes `warn`, and NOTICE/INFO become
    /// `info`. Everything else, including out-of-range values, becomes `debug`.
    fn trace_locally(severity: u8, logger: &str, message: &str) {
        match severity {
            level::EMERGENCY | level::ALERT | level::CRITICAL | level::ERROR => {
                tracing::error!(logger, "{}", message);
            }
            level::WARNING => {
                tracing::warn!(logger, "{}", message);
            }
            level::NOTICE | level::INFO => {
                tracing::info!(logger, "{}", message);
            }
            _ => {
                tracing::debug!(logger, "{}", message);
            }
        }
    }
}
/// Process-wide logger instance.
///
/// It starts with no MCP server attached and an `INFO` threshold (see
/// `McpLogger::new`). Until `init` is called, entries are only mirrored to
/// `tracing`.
pub(crate) static LOGGER: McpLogger = McpLogger::new();
#[cfg(test)]
mod tests {
    use super::*;
    use rust_mcp_sdk::schema::LoggingLevel;

    /// Returns a brand-new logger so each test starts from the default state
    /// instead of sharing the global `LOGGER`.
    fn fresh_logger() -> McpLogger {
        McpLogger::new()
    }

    /// Applies `lvl` to a fresh logger and reads back the stored numeric
    /// threshold.
    fn threshold_after(lvl: LoggingLevel) -> u8 {
        let logger = fresh_logger();
        logger.set_level(lvl);
        logger.min_level.load(Ordering::Relaxed)
    }

    #[test]
    fn set_level_warning_rejects_info() {
        let min = threshold_after(LoggingLevel::Warning);
        assert!(level::INFO > min, "INFO should be filtered when min=warning");
    }

    #[test]
    fn set_level_info_passes_warning() {
        let min = threshold_after(LoggingLevel::Info);
        assert!(level::WARNING <= min, "WARNING should pass when min=info");
    }

    #[test]
    fn set_level_debug_passes_all() {
        assert_eq!(threshold_after(LoggingLevel::Debug), level::DEBUG);
    }

    #[test]
    fn level_round_trip_all_variants() {
        let variants = [
            LoggingLevel::Emergency,
            LoggingLevel::Alert,
            LoggingLevel::Critical,
            LoggingLevel::Error,
            LoggingLevel::Warning,
            LoggingLevel::Notice,
            LoggingLevel::Info,
            LoggingLevel::Debug,
        ];
        // `LoggingLevel` is `Copy`, so iterating the array by value is free.
        for variant in variants {
            let numeric = logging_level_to_u8(variant);
            let restored = u8_to_logging_level(numeric);
            assert_eq!(
                logging_level_to_u8(restored),
                numeric,
                "round-trip failed for {variant:?}"
            );
        }
    }

    #[test]
    fn current_level_reflects_set_level() {
        let stored = threshold_after(LoggingLevel::Error);
        let lvl = u8_to_logging_level(stored);
        assert_eq!(
            logging_level_to_u8(lvl),
            level::ERROR,
            "stored level should reflect the last set_level call"
        );
    }

    #[test]
    fn log_with_data_builds_notification_params() {
        let payload = serde_json::json!({"url": "https://example.com", "status": 200});
        let params = LoggingMessageNotificationParams {
            level: LoggingLevel::Info,
            logger: Some(String::from("fetch")),
            data: payload.clone(),
            meta: None,
        };
        assert_eq!(params.logger.as_deref(), Some("fetch"));
        assert_eq!(params.data, payload);
    }
}