use super::TelemetryLayerConfig;
use crate::attributes::McpSpanContext;
use crate::span_attributes::{MCP_DURATION_MS, MCP_ERROR_MESSAGE, MCP_STATUS};
use futures_util::future::BoxFuture;
use std::sync::Arc;
use std::task::{Context, Poll};
use std::time::Instant;
use tower_service::Service;
use tracing::{Instrument, Span, info};
#[derive(Debug, Clone)]
pub struct TelemetryService<S> {
inner: S,
config: Arc<TelemetryLayerConfig>,
}
impl<S> TelemetryService<S> {
pub fn new(inner: S, config: Arc<TelemetryLayerConfig>) -> Self {
Self { inner, config }
}
#[must_use]
pub fn inner(&self) -> &S {
&self.inner
}
pub fn inner_mut(&mut self) -> &mut S {
&mut self.inner
}
#[must_use]
pub fn config(&self) -> &TelemetryLayerConfig {
&self.config
}
}
pub type TelemetryServiceFuture<F> = BoxFuture<'static, <F as Future>::Output>;
impl<S> Service<serde_json::Value> for TelemetryService<S>
where
S: Service<serde_json::Value, Response = serde_json::Value> + Clone + Send + 'static,
S::Future: Send,
S::Error: std::fmt::Display + Send,
{
type Response = serde_json::Value;
type Error = S::Error;
type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.inner.poll_ready(cx)
}
#[allow(clippy::too_many_lines)]
fn call(&mut self, req: serde_json::Value) -> Self::Future {
let start = Instant::now();
let config = Arc::clone(&self.config);
let method = req
.get("method")
.and_then(|m| m.as_str())
.unwrap_or("unknown")
.to_string();
let request_id = if config.redact_request_id {
None
} else {
req.get("id").map(|id| match id {
serde_json::Value::Number(n) => n.to_string(),
serde_json::Value::String(s) => s.clone(),
_ => "unknown".to_string(),
})
};
if !config.should_instrument(&method) {
let inner = self.inner.clone();
let mut inner = std::mem::replace(&mut self.inner, inner);
return Box::pin(async move { inner.call(req).await });
}
let tool_name = if method == "tools/call" {
req.get("params")
.and_then(|p| p.get("name"))
.and_then(|n| n.as_str())
.map(String::from)
} else {
None
};
let resource_uri = if method == "resources/read" && !config.redact_resource_uri {
req.get("params")
.and_then(|p| p.get("uri"))
.and_then(|u| u.as_str())
.map(String::from)
} else {
None
};
let prompt_name = if method == "prompts/get" {
req.get("params")
.and_then(|p| p.get("name"))
.and_then(|n| n.as_str())
.map(String::from)
} else {
None
};
let mut span_ctx = McpSpanContext::new()
.method(&method)
.server(&config.service_name, &config.service_version);
if let Some(ref id) = request_id {
span_ctx = span_ctx.request_id(id);
}
if let Some(ref name) = tool_name {
span_ctx = span_ctx.tool_name(name);
}
if let Some(ref uri) = resource_uri {
span_ctx = span_ctx.resource_uri(uri);
}
if let Some(ref name) = prompt_name {
span_ctx = span_ctx.prompt_name(name);
}
let span = span_ctx.into_span();
let request_size = if config.record_sizes {
Some(req.to_string().len())
} else {
None
};
let inner = self.inner.clone();
let mut inner = std::mem::replace(&mut self.inner, inner);
Box::pin(
async move {
let result = inner.call(req).await;
let duration = start.elapsed();
let (success, error_msg) = match &result {
Ok(response) => {
let is_error = response.get("error").is_some();
if is_error {
let error_message = response
.get("error")
.and_then(|e| e.get("message"))
.and_then(|m| m.as_str())
.map(String::from);
(false, error_message)
} else {
(true, None)
}
}
Err(e) => (false, Some(e.to_string())),
};
if config.record_timing {
let current_span = Span::current();
let duration_ms = i64::try_from(duration.as_millis()).unwrap_or(i64::MAX);
current_span.record(MCP_DURATION_MS, duration_ms);
current_span.record(MCP_STATUS, if success { "success" } else { "error" });
if let Some(ref err) = error_msg {
let truncated = truncate_error_message(err, config.error_message_max_len);
if !truncated.is_empty() {
current_span.record(MCP_ERROR_MESSAGE, truncated.as_ref());
}
}
info!(
method = %method,
duration_ms = duration.as_millis(),
success = success,
request_size = request_size,
"MCP request completed"
);
}
result
}
.instrument(span),
)
}
}
impl<S, B> Service<http::Request<B>> for TelemetryService<S>
where
S: Service<http::Request<B>> + Clone + Send + 'static,
S::Response: Send,
S::Future: Send,
S::Error: std::fmt::Display + Send,
B: Send + 'static,
{
type Response = S::Response;
type Error = S::Error;
type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.inner.poll_ready(cx)
}
fn call(&mut self, req: http::Request<B>) -> Self::Future {
let start = Instant::now();
let config = Arc::clone(&self.config);
let path = req.uri().path();
let method = path.strip_prefix('/').unwrap_or(path).to_string();
if !config.should_instrument(&method) {
let inner = self.inner.clone();
let mut inner = std::mem::replace(&mut self.inner, inner);
return Box::pin(async move { inner.call(req).await });
}
let span = McpSpanContext::new()
.method(&method)
.transport("http")
.server(&config.service_name, &config.service_version)
.into_span();
let inner = self.inner.clone();
let mut inner = std::mem::replace(&mut self.inner, inner);
Box::pin(
async move {
let result = inner.call(req).await;
let duration = start.elapsed();
let success = result.is_ok();
if config.record_timing {
let current_span = Span::current();
let duration_ms = i64::try_from(duration.as_millis()).unwrap_or(i64::MAX);
current_span.record(MCP_DURATION_MS, duration_ms);
current_span.record(MCP_STATUS, if success { "success" } else { "error" });
info!(
method = %method,
duration_ms = duration.as_millis(),
success = success,
"HTTP request completed"
);
}
result
}
.instrument(span),
)
}
}
fn truncate_error_message(msg: &str, max_len: usize) -> std::borrow::Cow<'_, str> {
if max_len == 0 {
return std::borrow::Cow::Borrowed("");
}
if msg.len() <= max_len {
return std::borrow::Cow::Borrowed(msg);
}
let mut end = max_len;
while end > 0 && !msg.is_char_boundary(end) {
end -= 1;
}
let mut out = String::with_capacity(end + 14);
out.push_str(&msg[..end]);
out.push_str("…[truncated]");
std::borrow::Cow::Owned(out)
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_service_creation() {
let config = Arc::new(TelemetryLayerConfig::default());
assert!(config.record_timing);
assert!(config.record_sizes);
}
#[test]
fn test_truncate_error_message_short() {
let msg = "boom";
let out = truncate_error_message(msg, 512);
assert_eq!(out, "boom");
}
#[test]
fn test_truncate_error_message_long() {
let msg = "x".repeat(1024);
let out = truncate_error_message(&msg, 16);
assert!(out.starts_with("xxxxxxxxxxxxxxxx"));
assert!(out.ends_with("…[truncated]"));
}
#[test]
fn test_truncate_error_message_zero_len() {
let out = truncate_error_message("anything", 0);
assert_eq!(out, "");
}
#[test]
fn test_truncate_error_message_utf8_boundary() {
let msg = "héllo world";
let out = truncate_error_message(msg, 5);
assert!(out.starts_with("héll"));
}
}