use std::collections::HashMap;
use std::sync::atomic::{AtomicBool, Ordering};
use opentelemetry::Context;
use opentelemetry::propagation::{TextMapPropagator, Extractor, Injector};
use opentelemetry_sdk::propagation::TraceContextPropagator;
use serde_json::{json, Value};
use tracing::{debug, instrument, Span};
/// Process-wide flag recording whether `init_mcp_instrumentation` has already run.
/// It is set with `swap` on init, read by `is_instrumentation_applied`, and cleared
/// only by the test-only `reset_instrumentation`.
static INSTRUMENTATION_APPLIED: AtomicBool = AtomicBool::new(false);
/// A payload bundled with the trace-context key/value pairs that should travel with it.
#[derive(Debug, Clone)]
pub struct ItemWithContext<T> {
    /// The wrapped payload.
    pub item: T,
    /// Trace-propagation carrier entries (e.g. `traceparent`) captured for this item.
    pub context: HashMap<String, String>,
}
impl<T> ItemWithContext<T> {
    /// Wraps `item`, capturing whatever trace context is current at call time.
    pub fn new(item: T) -> Self {
        Self::with_context(item, get_current_trace_context())
    }
    /// Wraps `item` together with an explicitly supplied carrier map.
    pub fn with_context(item: T, context: HashMap<String, String>) -> Self {
        Self { item, context }
    }
    /// Discards the carrier map and returns the payload.
    pub fn into_inner(self) -> T {
        let Self { item, .. } = self;
        item
    }
}
/// Marks MCP instrumentation as initialized.
///
/// Idempotent: only the first call (since process start or the last test reset)
/// logs the initialization message; later calls just log that it already happened.
pub fn init_mcp_instrumentation() {
    // `swap` returns the previous value, so `true` means a prior call won the race.
    if INSTRUMENTATION_APPLIED.swap(true, Ordering::SeqCst) {
        debug!("MCP instrumentation already applied");
    } else {
        debug!("Initializing MCP OpenTelemetry instrumentation");
    }
}
/// Reports whether [`init_mcp_instrumentation`] has been called.
pub fn is_instrumentation_applied() -> bool {
    INSTRUMENTATION_APPLIED.load(Ordering::SeqCst)
}
/// Test helper: clears the "applied" flag so initialization can be observed again.
#[cfg(test)]
pub fn reset_instrumentation() {
    INSTRUMENTATION_APPLIED.store(false, Ordering::SeqCst)
}
/// Shared trace-context propagator, built lazily on first use.
static PROPAGATOR: std::sync::OnceLock<TraceContextPropagator> = std::sync::OnceLock::new();
/// Returns the process-wide [`TraceContextPropagator`], constructing it on first access.
fn get_propagator() -> &'static TraceContextPropagator {
    PROPAGATOR.get_or_init(TraceContextPropagator::new)
}
/// Serializes the current OpenTelemetry context into a string carrier map.
///
/// The keys written are whatever the shared propagator emits. The map is empty
/// when the propagator writes nothing, e.g. when no span context is current.
pub fn get_current_trace_context() -> HashMap<String, String> {
    let mut carrier = HashMap::new();
    get_propagator().inject_context(&Context::current(), &mut HashMapInjector(&mut carrier));
    carrier
}
/// Write adapter that lets the propagator store carrier entries in a borrowed map.
struct HashMapInjector<'a>(&'a mut HashMap<String, String>);
impl Injector for HashMapInjector<'_> {
    /// Inserts (or overwrites) `key` with `value`.
    fn set(&mut self, key: &str, value: String) {
        let owned_key = key.to_owned();
        self.0.insert(owned_key, value);
    }
}
/// Read-only adapter that lets the propagator read carrier entries from a borrowed map.
struct HashMapExtractor<'a>(&'a HashMap<String, String>);
impl Extractor for HashMapExtractor<'_> {
    /// Looks up `key` exactly as stored in the map.
    fn get(&self, key: &str) -> Option<&str> {
        self.0.get(key).map(String::as_str)
    }
    /// Lists every key present in the map.
    fn keys(&self) -> Vec<&str> {
        self.0.keys().map(String::as_str).collect()
    }
}
/// Writes the current trace context into `request["_meta"]`.
///
/// Behavior:
/// - Does nothing when `request` is not a JSON object.
/// - Does nothing when there is no context to propagate (the propagator wrote no keys).
/// - Creates `_meta` as an empty object if it is missing.
/// - Overwrites propagated keys that already exist in `_meta`.
/// - Leaves a pre-existing non-object `_meta` untouched; nothing is injected in that case.
#[instrument(skip(request), level = "trace")]
pub fn inject_trace_context(request: &mut Value) {
    let Some(obj) = request.as_object_mut() else {
        return;
    };
    let context = get_current_trace_context();
    if context.is_empty() {
        return;
    }
    let meta = obj.entry("_meta").or_insert_with(|| json!({}));
    let Some(meta_obj) = meta.as_object_mut() else {
        // Never clobber a caller-supplied non-object `_meta`. Report it instead
        // of logging a misleading "Injected" message.
        debug!("MCP request `_meta` is not an object; trace context not injected");
        return;
    };
    meta_obj.extend(
        context
            .into_iter()
            .map(|(key, value)| (key, Value::String(value))),
    );
    debug!("Injected trace context into MCP request");
}
#[instrument(skip(request), level = "trace")]
pub fn extract_trace_context(request: &Value) -> HashMap<String, String> {
let mut context = HashMap::new();
let meta = match request.get("_meta").and_then(|m| m.as_object()) {
Some(m) => m,
None => return context,
};
for key in &["traceparent", "tracestate", "baggage"] {
if let Some(value) = meta.get(*key).and_then(|v| v.as_str()) {
context.insert((*key).to_string(), value.to_string());
}
}
if !context.is_empty() {
debug!(keys = ?context.keys().collect::<Vec<_>>(), "Extracted trace context from MCP request");
}
context
}
/// Extracts the trace context carried in `request["_meta"]` and attaches it as
/// the current OpenTelemetry context for as long as the returned guard lives.
///
/// When the request carries no trace keys, the current context is re-attached
/// unchanged, so callers always receive a guard.
///
/// Keep the guard alive while work should run under the remote parent. Per the
/// `opentelemetry::Context::attach` documentation, dropping it restores the
/// previously active context.
#[must_use = "dropping the guard immediately detaches the extracted context"]
pub fn extract_and_activate_context(request: &Value) -> ContextGuard {
    let context_map = extract_trace_context(request);
    if context_map.is_empty() {
        return ContextGuard::new(Context::current());
    }
    let extractor = HashMapExtractor(&context_map);
    let extracted_context = get_propagator().extract(&extractor);
    // The guard for the *extracted* context must be the one handed back.
    // Attaching it into a local guard would detach it again when this
    // function returns, leaving the caller under the old context.
    ContextGuard::new(extracted_context)
}
/// Keeps an OpenTelemetry [`Context`] attached as current while this value is alive.
///
/// Dropping it drops the inner `opentelemetry::ContextGuard` obtained from `Context::attach`.
#[must_use = "the context is detached as soon as the guard is dropped"]
pub struct ContextGuard {
    /// Guard returned by `Context::attach`; held only for its drop behavior.
    _active_guard: opentelemetry::ContextGuard,
}
impl ContextGuard {
    /// Attaches `cx` as the current context and wraps the resulting guard.
    fn new(cx: Context) -> Self {
        Self {
            _active_guard: cx.attach(),
        }
    }
}
/// Activates the trace context carried by `request`, then opens an
/// `mcp.tool.call` span for `tool_name` while that context is current.
///
/// Returns the span together with the context guard. Hold both for the
/// duration of the tool call.
#[must_use]
pub fn create_mcp_tool_span(tool_name: &str, request: &Value) -> (Span, ContextGuard) {
    // The guard is created first so the extracted context is already current
    // when the span is built. Presumably this lets an OTel-aware subscriber use
    // it as the parent; confirm against the subscriber setup.
    let guard = extract_and_activate_context(request);
    (
        tracing::info_span!(
            "mcp.tool.call",
            tool.name = %tool_name,
            otel.kind = "server",
            mcp.instrumented = true
        ),
        guard,
    )
}
/// Carriers that can have the current trace context written into them.
pub trait InjectableContext {
    /// Writes the current trace context into `self`.
    fn inject_context(&mut self);
    /// Reports whether `self` already carries a `traceparent` entry.
    fn has_context(&self) -> bool;
}
impl InjectableContext for Value {
    /// Delegates to [`inject_trace_context`], which writes into `self["_meta"]`.
    fn inject_context(&mut self) {
        inject_trace_context(self);
    }
    /// True when `self["_meta"]["traceparent"]` is present, whatever its JSON type.
    fn has_context(&self) -> bool {
        let Some(meta) = self.get("_meta") else {
            return false;
        };
        meta.get("traceparent").is_some()
    }
}
/// Carriers from which trace-propagation entries can be read.
pub trait ExtractableContext {
    /// Returns the trace-propagation key/value pairs found on `self`.
    fn extract_context(&self) -> HashMap<String, String>;
}
impl ExtractableContext for Value {
    /// Delegates to [`extract_trace_context`], which reads the string-valued
    /// `traceparent`, `tracestate` and `baggage` entries from `self["_meta"]`.
    fn extract_context(&self) -> HashMap<String, String> {
        let carrier = extract_trace_context(self);
        carrier
    }
}
/// Switches for MCP trace-context propagation.
///
/// `Default` turns everything off and has no extra headers.
#[derive(Debug, Clone, Default)]
pub struct MCPInstrumentationConfig {
    /// Client-side injection switch (not read elsewhere in the visible code).
    pub inject_client_context: bool,
    /// Server-side extraction switch (not read elsewhere in the visible code).
    pub extract_server_context: bool,
    /// Extra header names to propagate; empty unless set via [`Self::with_headers`].
    pub additional_headers: Vec<String>,
}
impl MCPInstrumentationConfig {
    /// Builds a config with the given switches and no additional headers.
    fn from_flags(inject_client_context: bool, extract_server_context: bool) -> Self {
        Self {
            inject_client_context,
            extract_server_context,
            ..Self::default()
        }
    }
    /// Both injection and extraction enabled.
    pub fn enabled() -> Self {
        Self::from_flags(true, true)
    }
    /// Injection only.
    pub fn client_only() -> Self {
        Self::from_flags(true, false)
    }
    /// Extraction only.
    pub fn server_only() -> Self {
        Self::from_flags(false, true)
    }
    /// Replaces the additional header list, keeping both switches as they are.
    pub fn with_headers(self, headers: Vec<String>) -> Self {
        Self {
            additional_headers: headers,
            ..self
        }
    }
}
/// Holds an [`MCPInstrumentationConfig`].
///
/// Constructing one calls `init_mcp_instrumentation`, which is idempotent:
/// after the first call it only logs that instrumentation was already applied.
pub struct InstrumentationGuard {
    /// Retained configuration; not read anywhere in the visible code.
    _config: MCPInstrumentationConfig,
}
impl InstrumentationGuard {
    /// Initializes MCP instrumentation and stores `config`.
    pub fn new(config: MCPInstrumentationConfig) -> Self {
        let guard = Self { _config: config };
        init_mcp_instrumentation();
        guard
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A well-formed W3C `traceparent` value shared by several tests.
    const SAMPLE_TRACEPARENT: &str = "00-0af7651916cd43dd8448eb211c80319c-b7ad6b7169203331-01";

    #[test]
    fn test_inject_trace_context() {
        let mut request = json!({
            "name": "test_tool",
            "arguments": {}
        });
        let before = request.clone();
        inject_trace_context(&mut request);
        // No OpenTelemetry context is attached in this test, so there is
        // nothing to propagate and the request must be left unchanged.
        assert_eq!(request, before);
    }

    #[test]
    fn test_extract_trace_context_empty() {
        let request = json!({
            "name": "test_tool",
            "arguments": {}
        });
        let context = extract_trace_context(&request);
        assert!(context.is_empty());
    }

    #[test]
    fn test_extract_trace_context_with_meta() {
        let request = json!({
            "name": "test_tool",
            "_meta": {
                "traceparent": SAMPLE_TRACEPARENT
            }
        });
        let context = extract_trace_context(&request);
        assert_eq!(
            context.get("traceparent"),
            Some(&SAMPLE_TRACEPARENT.to_string())
        );
        assert_eq!(context.len(), 1);
    }

    #[test]
    fn test_extract_trace_context_all_keys() {
        let request = json!({
            "_meta": {
                "traceparent": SAMPLE_TRACEPARENT,
                "tracestate": "vendor=value",
                "baggage": "user=alice",
                "unrelated": "ignored"
            }
        });
        let context = extract_trace_context(&request);
        assert_eq!(context.len(), 3);
        assert_eq!(context.get("tracestate").map(String::as_str), Some("vendor=value"));
        assert_eq!(context.get("baggage").map(String::as_str), Some("user=alice"));
        assert!(!context.contains_key("unrelated"));
    }

    #[test]
    fn test_extract_trace_context_ignores_malformed_meta() {
        // Non-string values are skipped.
        let non_string = json!({ "_meta": { "traceparent": 42 } });
        assert!(extract_trace_context(&non_string).is_empty());
        // `_meta` that is not an object yields nothing.
        let non_object_meta = json!({ "_meta": "traceparent" });
        assert!(extract_trace_context(&non_object_meta).is_empty());
        // Non-object requests yield nothing.
        assert!(extract_trace_context(&json!(null)).is_empty());
        assert!(extract_trace_context(&json!([1, 2, 3])).is_empty());
    }

    #[test]
    fn test_injectable_context_trait() {
        let mut request = json!({
            "name": "test_tool"
        });
        assert!(!request.has_context());
        request.inject_context();
        assert!(!request.has_context());
    }

    #[test]
    fn test_has_context_detects_traceparent() {
        let request = json!({ "_meta": { "traceparent": SAMPLE_TRACEPARENT } });
        assert!(request.has_context());
        let other_meta = json!({ "_meta": { "tracestate": "vendor=value" } });
        assert!(!other_meta.has_context());
    }

    #[test]
    fn test_extractable_context_trait() {
        let request = json!({
            "_meta": { "traceparent": SAMPLE_TRACEPARENT, "baggage": "k=v" }
        });
        assert_eq!(request.extract_context(), extract_trace_context(&request));
    }

    #[test]
    fn test_instrumentation_idempotent() {
        reset_instrumentation();
        assert!(!is_instrumentation_applied());
        init_mcp_instrumentation();
        assert!(is_instrumentation_applied());
        init_mcp_instrumentation();
        assert!(is_instrumentation_applied());
    }

    #[test]
    fn test_item_with_context() {
        let item = "test_message";
        let wrapped = ItemWithContext::new(item);
        assert_eq!(wrapped.item, "test_message");
        // No context is attached in this test, so nothing is captured.
        assert!(wrapped.context.is_empty());
        assert_eq!(wrapped.into_inner(), "test_message");
    }

    #[test]
    fn test_item_with_explicit_context() {
        let mut carrier = HashMap::new();
        carrier.insert("traceparent".to_string(), SAMPLE_TRACEPARENT.to_string());
        let wrapped = ItemWithContext::with_context(7u32, carrier.clone());
        assert_eq!(wrapped.context, carrier);
        assert_eq!(wrapped.into_inner(), 7);
    }

    #[test]
    fn test_config_builders() {
        let enabled = MCPInstrumentationConfig::enabled();
        assert!(enabled.inject_client_context);
        assert!(enabled.extract_server_context);
        let client_only = MCPInstrumentationConfig::client_only();
        assert!(client_only.inject_client_context);
        assert!(!client_only.extract_server_context);
        let server_only = MCPInstrumentationConfig::server_only();
        assert!(!server_only.inject_client_context);
        assert!(server_only.extract_server_context);
        let defaults = MCPInstrumentationConfig::default();
        assert!(!defaults.inject_client_context);
        assert!(!defaults.extract_server_context);
        assert!(defaults.additional_headers.is_empty());
        let with_headers =
            MCPInstrumentationConfig::enabled().with_headers(vec!["x-request-id".to_string()]);
        assert_eq!(with_headers.additional_headers, vec!["x-request-id".to_string()]);
        assert!(with_headers.inject_client_context);
        assert!(with_headers.extract_server_context);
    }
}