use std::convert::Infallible;
use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll};
use std::time::Instant;
use tower::Layer;
use tower_service::Service;
use tracing::{Instrument, Level, Span};
use crate::protocol::McpRequest;
use crate::router::{RouterRequest, RouterResponse};
/// Tower [`Layer`] that wraps an MCP router service with request tracing.
///
/// Every request handled by the wrapped service gets an `mcp_request` span
/// (method, request id, operation category and target) and a completion event
/// once the response is available.
#[derive(Debug, Clone, Copy)]
pub struct McpTracingLayer {
    /// Level used for the per-request span and the success event.
    level: Level,
}

impl Default for McpTracingLayer {
    /// Equivalent to [`McpTracingLayer::new`], i.e. tracing at [`Level::INFO`].
    fn default() -> Self {
        Self::new()
    }
}

impl McpTracingLayer {
    /// Creates a layer that traces requests at [`Level::INFO`].
    #[must_use]
    pub fn new() -> Self {
        Self { level: Level::INFO }
    }

    /// Sets the level used for the request span and the success event.
    ///
    /// Error responses are logged at `WARN` by the service regardless of this
    /// setting.
    #[must_use]
    pub fn level(mut self, level: Level) -> Self {
        self.level = level;
        self
    }
}

impl<S> Layer<S> for McpTracingLayer {
    type Service = McpTracingService<S>;

    /// Wraps `inner`, carrying over this layer's configured level.
    fn layer(&self, inner: S) -> Self::Service {
        McpTracingService {
            inner,
            level: self.level,
        }
    }
}
/// Service produced by [`McpTracingLayer`].
///
/// Forwards each request to the inner router service, recording an
/// `mcp_request` span around it and a completion (or failure) event when the
/// response is ready.
#[derive(Debug, Clone)]
pub struct McpTracingService<S> {
    inner: S,
    level: Level,
}

impl<S> Service<RouterRequest> for McpTracingService<S>
where
    S: Service<RouterRequest, Response = RouterResponse, Error = Infallible>
        + Clone
        + Send
        + 'static,
    S::Future: Send,
{
    type Response = RouterResponse;
    type Error = Infallible;
    type Future = Pin<Box<dyn Future<Output = Result<RouterResponse, Infallible>> + Send>>;

    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx)
    }

    fn call(&mut self, req: RouterRequest) -> Self::Future {
        let method = req.inner.method_name().to_string();
        let request_id = format!("{:?}", req.id);
        let (operation_name, operation_target) = extract_operation_details(&req.inner);
        let span = create_span(
            self.level,
            &method,
            &request_id,
            operation_name,
            operation_target,
        );
        let level = self.level;
        let start = Instant::now();

        // Enter the span while invoking the inner service so that anything it
        // records synchronously inside `call` is attributed to this request,
        // not only what happens when the returned future is polled. The guard
        // is dropped at the end of this block, before `span` is moved below.
        let fut = {
            let _entered = span.enter();
            self.inner.call(req)
        };

        Box::pin(
            async move {
                // `Infallible` has no values, so the error arm is uninhabited.
                let response = match fut.await {
                    Ok(response) => response,
                    Err(never) => match never {},
                };
                let duration_ms = start.elapsed().as_secs_f64() * 1000.0;
                match &response.inner {
                    Ok(_) => log_success(level, &method, duration_ms),
                    Err(err) => {
                        // Error responses are always reported at WARN,
                        // independent of the configured level.
                        tracing::warn!(
                            method = %method,
                            error_code = err.code,
                            error_message = %err.message,
                            duration_ms = duration_ms,
                            "MCP request failed"
                        );
                    }
                }
                Ok::<_, Infallible>(response)
            }
            .instrument(span),
        )
    }
}
/// Derives a coarse operation category and a target from an MCP request.
///
/// The tracing service uses the pair to fill the `operation` and `target`
/// fields of the request span. Examples:
/// - `CallTool` yields `(Some("tool"), Some(<tool name>))`.
/// - `ListTools` yields `(Some("list"), Some("tools"))`.
/// - `Ping` yields `(Some("ping"), None)`.
///
/// Variants not matched explicitly fall back to `(Some("unknown"), None)`.
pub(crate) fn extract_operation_details(
    req: &McpRequest,
) -> (Option<&'static str>, Option<String>) {
    match req {
        McpRequest::CallTool(params) => (Some("tool"), Some(params.name.clone())),
        McpRequest::ReadResource(params) => (Some("resource"), Some(params.uri.clone())),
        McpRequest::GetPrompt(params) => (Some("prompt"), Some(params.name.clone())),
        McpRequest::ListTools(_) => (Some("list"), Some("tools".to_string())),
        McpRequest::ListResources(_) => (Some("list"), Some("resources".to_string())),
        McpRequest::ListResourceTemplates(_) => {
            (Some("list"), Some("resource_templates".to_string()))
        }
        McpRequest::ListPrompts(_) => (Some("list"), Some("prompts".to_string())),
        McpRequest::SubscribeResource(params) => (Some("subscribe"), Some(params.uri.clone())),
        McpRequest::UnsubscribeResource(params) => {
            (Some("unsubscribe"), Some(params.uri.clone()))
        }
        McpRequest::ListTasks(_) => (Some("list"), Some("tasks".to_string())),
        McpRequest::GetTaskInfo(params) => (Some("task"), Some(params.task_id.clone())),
        McpRequest::GetTaskResult(params) => (Some("task_result"), Some(params.task_id.clone())),
        McpRequest::CancelTask(params) => (Some("cancel"), Some(params.task_id.clone())),
        McpRequest::Complete(params) => {
            // Target is "<kind>:<identifier>" so resource and prompt
            // completions stay distinguishable in the span.
            let reference = match &params.reference {
                crate::protocol::CompletionReference::Resource { uri } => {
                    format!("resource:{}", uri)
                }
                crate::protocol::CompletionReference::Prompt { name } => {
                    format!("prompt:{}", name)
                }
                // Any other reference kind is not broken out further.
                _ => "unknown".to_string(),
            };
            (Some("complete"), Some(reference))
        }
        McpRequest::SetLoggingLevel(params) => {
            // Uses the level's `Debug` representation as the target.
            (Some("logging"), Some(format!("{:?}", params.level)))
        }
        McpRequest::Initialize(_) => (Some("init"), None),
        McpRequest::Ping => (Some("ping"), None),
        McpRequest::Unknown { method, .. } => (Some("unknown"), Some(method.clone())),
        // Any request variant not listed above.
        _ => (Some("unknown"), None),
    }
}
fn create_span(
level: Level,
method: &str,
request_id: &str,
operation_name: Option<&str>,
operation_target: Option<String>,
) -> Span {
match level {
Level::TRACE => tracing::trace_span!(
"mcp_request",
method = %method,
request_id = %request_id,
operation = operation_name,
target = operation_target.as_deref(),
),
Level::DEBUG => tracing::debug_span!(
"mcp_request",
method = %method,
request_id = %request_id,
operation = operation_name,
target = operation_target.as_deref(),
),
Level::INFO => tracing::info_span!(
"mcp_request",
method = %method,
request_id = %request_id,
operation = operation_name,
target = operation_target.as_deref(),
),
Level::WARN => tracing::warn_span!(
"mcp_request",
method = %method,
request_id = %request_id,
operation = operation_name,
target = operation_target.as_deref(),
),
Level::ERROR => tracing::error_span!(
"mcp_request",
method = %method,
request_id = %request_id,
operation = operation_name,
target = operation_target.as_deref(),
),
}
}
fn log_success(level: Level, method: &str, duration_ms: f64) {
match level {
Level::TRACE => {
tracing::trace!(method = %method, duration_ms = duration_ms, "MCP request completed")
}
Level::DEBUG => {
tracing::debug!(method = %method, duration_ms = duration_ms, "MCP request completed")
}
Level::INFO => {
tracing::info!(method = %method, duration_ms = duration_ms, "MCP request completed")
}
Level::WARN => {
tracing::warn!(method = %method, duration_ms = duration_ms, "MCP request completed")
}
Level::ERROR => {
tracing::error!(method = %method, duration_ms = duration_ms, "MCP request completed")
}
}
}
// Unit tests for the layer builder and for the request -> (operation, target)
// mapping used to populate span fields.
#[cfg(test)]
mod tests {
    use super::*;

    // `new()` defaults to INFO; the `level` builder overrides it.
    #[test]
    fn test_layer_creation() {
        let layer = McpTracingLayer::new();
        assert_eq!(layer.level, Level::INFO);
        let layer = McpTracingLayer::new().level(Level::DEBUG);
        assert_eq!(layer.level, Level::DEBUG);
    }

    // Tool, resource and prompt requests report their name/URI as the target;
    // `Ping` has a category but no target.
    #[test]
    fn test_extract_operation_details() {
        use crate::protocol::{CallToolParams, GetPromptParams, ReadResourceParams};
        use serde_json::Value;
        use std::collections::HashMap;
        let req = McpRequest::CallTool(CallToolParams {
            name: "my_tool".to_string(),
            arguments: Value::Null,
            meta: None,
            task: None,
        });
        let (name, target) = extract_operation_details(&req);
        assert_eq!(name, Some("tool"));
        assert_eq!(target, Some("my_tool".to_string()));
        let req = McpRequest::ReadResource(ReadResourceParams {
            uri: "file:///test.txt".to_string(),
            meta: None,
        });
        let (name, target) = extract_operation_details(&req);
        assert_eq!(name, Some("resource"));
        assert_eq!(target, Some("file:///test.txt".to_string()));
        let req = McpRequest::GetPrompt(GetPromptParams {
            name: "my_prompt".to_string(),
            arguments: HashMap::new(),
            meta: None,
        });
        let (name, target) = extract_operation_details(&req);
        assert_eq!(name, Some("prompt"));
        assert_eq!(target, Some("my_prompt".to_string()));
        let req = McpRequest::Ping;
        let (name, target) = extract_operation_details(&req);
        assert_eq!(name, Some("ping"));
        assert_eq!(target, None);
    }

    // The four list requests share the "list" category and differ only in
    // the target naming what is being listed.
    #[test]
    fn test_extract_operation_details_list_operations() {
        use crate::protocol::{
            ListPromptsParams, ListResourceTemplatesParams, ListResourcesParams, ListToolsParams,
        };
        let (name, target) = extract_operation_details(&McpRequest::ListTools(ListToolsParams {
            cursor: None,
            meta: None,
        }));
        assert_eq!(name, Some("list"));
        assert_eq!(target, Some("tools".to_string()));
        let (name, target) =
            extract_operation_details(&McpRequest::ListResources(ListResourcesParams {
                cursor: None,
                meta: None,
            }));
        assert_eq!(name, Some("list"));
        assert_eq!(target, Some("resources".to_string()));
        let (name, target) = extract_operation_details(&McpRequest::ListResourceTemplates(
            ListResourceTemplatesParams {
                cursor: None,
                meta: None,
            },
        ));
        assert_eq!(name, Some("list"));
        assert_eq!(target, Some("resource_templates".to_string()));
        let (name, target) =
            extract_operation_details(&McpRequest::ListPrompts(ListPromptsParams {
                cursor: None,
                meta: None,
            }));
        assert_eq!(name, Some("list"));
        assert_eq!(target, Some("prompts".to_string()));
    }

    // `Initialize` maps to the "init" category with no target.
    #[test]
    fn test_extract_operation_details_initialize() {
        use crate::protocol::{ClientCapabilities, Implementation, InitializeParams};
        let req = McpRequest::Initialize(InitializeParams {
            protocol_version: "2025-11-25".to_string(),
            capabilities: ClientCapabilities::default(),
            client_info: Implementation {
                name: "test".to_string(),
                version: "1.0".to_string(),
                ..Default::default()
            },
            meta: None,
        });
        let (name, target) = extract_operation_details(&req);
        assert_eq!(name, Some("init"));
        assert_eq!(target, None);
    }

    // Subscriptions report the subscribed resource URI as the target.
    #[test]
    fn test_extract_operation_details_subscribe() {
        use crate::protocol::SubscribeResourceParams;
        let req = McpRequest::SubscribeResource(SubscribeResourceParams {
            uri: "file:///watched.txt".to_string(),
            meta: None,
        });
        let (name, target) = extract_operation_details(&req);
        assert_eq!(name, Some("subscribe"));
        assert_eq!(target, Some("file:///watched.txt".to_string()));
    }

    // Only the presence of a target is asserted. Its text is the `Debug`
    // formatting of the requested level, which this test does not pin.
    #[test]
    fn test_extract_operation_details_logging_level() {
        use crate::protocol::{LogLevel, SetLogLevelParams};
        let req = McpRequest::SetLoggingLevel(SetLogLevelParams {
            level: LogLevel::Debug,
            meta: None,
        });
        let (name, target) = extract_operation_details(&req);
        assert_eq!(name, Some("logging"));
        assert!(target.is_some());
    }

    // Completion targets are prefixed with the reference kind
    // ("prompt:" / "resource:").
    #[test]
    fn test_extract_operation_details_completion() {
        use crate::protocol::{CompleteParams, CompletionArgument, CompletionReference};
        let req = McpRequest::Complete(CompleteParams {
            reference: CompletionReference::Prompt {
                name: "my-prompt".to_string(),
            },
            argument: CompletionArgument::new("arg1", "val"),
            context: None,
            meta: None,
        });
        let (name, target) = extract_operation_details(&req);
        assert_eq!(name, Some("complete"));
        assert_eq!(target, Some("prompt:my-prompt".to_string()));
        let req = McpRequest::Complete(CompleteParams {
            reference: CompletionReference::Resource {
                uri: "file:///test".to_string(),
            },
            argument: CompletionArgument::new("arg1", "val"),
            context: None,
            meta: None,
        });
        let (_, target) = extract_operation_details(&req);
        assert_eq!(target, Some("resource:file:///test".to_string()));
    }

    // Unrecognized methods keep their raw method name as the target.
    #[test]
    fn test_extract_operation_details_unknown_method() {
        let req = McpRequest::Unknown {
            method: "custom/method".to_string(),
            params: None,
        };
        let (name, target) = extract_operation_details(&req);
        assert_eq!(name, Some("unknown"));
        assert_eq!(target, Some("custom/method".to_string()));
    }

    // The builder accepts the remaining levels and stores them unchanged.
    #[test]
    fn test_layer_level_configuration() {
        let layer = McpTracingLayer::new().level(Level::TRACE);
        assert_eq!(layer.level, Level::TRACE);
        let layer = McpTracingLayer::new().level(Level::ERROR);
        assert_eq!(layer.level, Level::ERROR);
        let layer = McpTracingLayer::new().level(Level::WARN);
        assert_eq!(layer.level, Level::WARN);
    }
}