use std::sync::Arc;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use crate::error::McpAdapterError;
/// A single MCP tool invocation as it travels through the interceptor chain.
///
/// Interceptors receive it by value and may forward it (possibly modified) to
/// the next stage of the chain.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct McpToolCallRequest {
    /// Name of the tool to invoke (logged as `tool` by `LoggingInterceptor`).
    pub tool_name: String,
    /// Name of the server the call is addressed to (logged as `server`).
    pub server_name: String,
    /// Tool arguments as an arbitrary JSON value.
    pub arguments: Value,
}
/// Successful outcome of a tool call as seen by interceptors.
///
/// Failures of the chain itself are reported as `Err(McpAdapterError)` rather than
/// through this type.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct McpToolCallResult {
    /// JSON value produced by the call.
    pub value: Value,
    /// Error flag carried alongside `value`; presumably set when the tool itself
    /// reported an error result (NOTE(review): confirm against the producer).
    pub is_error: bool,
}
/// Continuation handed to a [`ToolCallInterceptor`].
///
/// Calling it with a request runs the remainder of the chain (later interceptors,
/// then the inner call) and returns that work as a boxed `Send` future borrowing
/// for `'a`. It is `FnOnce`, so an interceptor can forward at most once; never
/// calling it short-circuits everything after the current interceptor.
pub type InterceptorNext<'a> = Box<
    dyn FnOnce(
            McpToolCallRequest,
        ) -> std::pin::Pin<
            Box<
                dyn std::future::Future<Output = Result<McpToolCallResult, McpAdapterError>>
                    + Send
                    + 'a,
            >,
        > + Send
        + 'a,
>;
/// Middleware hook wrapped around an MCP tool call by [`run_interceptor_chain`].
///
/// An implementation receives the request together with a `next` continuation.
/// It may inspect or log the request, call `next` (at most once, since it is
/// `FnOnce`) to run the remaining interceptors and the inner call, and then
/// inspect or replace the result. Returning without calling `next`
/// short-circuits the rest of the chain.
///
/// Requires `Send + Sync`: the chain shares implementors as
/// `Arc<dyn ToolCallInterceptor>` and the futures they return must be `Send`.
pub trait ToolCallInterceptor: Send + Sync {
/// Handles one tool call.
///
/// `request` is the incoming call and `next` runs everything after this
/// interceptor. The returned future may borrow `self` (and `next`) for `'a`.
fn intercept<'a>(
&'a self,
request: McpToolCallRequest,
next: InterceptorNext<'a>,
) -> std::pin::Pin<
Box<
dyn std::future::Future<Output = Result<McpToolCallResult, McpAdapterError>>
+ Send
+ 'a,
>,
>;
}
/// Runs `request` through `interceptors` and finally through `inner`.
///
/// Interceptors nest like an onion: the first element of the slice sees the
/// request first and the result last. With an empty slice, `inner` is awaited
/// directly without building any chain.
pub async fn run_interceptor_chain<F>(
    interceptors: &[Arc<dyn ToolCallInterceptor>],
    request: McpToolCallRequest,
    inner: F,
) -> Result<McpToolCallResult, McpAdapterError>
where
    F: FnOnce(
            McpToolCallRequest,
        ) -> std::pin::Pin<
            Box<dyn std::future::Future<Output = Result<McpToolCallResult, McpAdapterError>> + Send>,
        > + Send
        + 'static,
{
    match interceptors {
        // Fast path: nothing to wrap the call in.
        [] => inner(request).await,
        chain => run_chain_from(chain, 0, request, inner).await,
    }
}
fn run_chain_from<'a, F>(
interceptors: &'a [Arc<dyn ToolCallInterceptor>],
index: usize,
request: McpToolCallRequest,
inner: F,
) -> std::pin::Pin<
Box<dyn std::future::Future<Output = Result<McpToolCallResult, McpAdapterError>> + Send + 'a>,
>
where
F: FnOnce(
McpToolCallRequest,
) -> std::pin::Pin<
Box<
dyn std::future::Future<Output = Result<McpToolCallResult, McpAdapterError>> + Send,
>,
> + Send
+ 'static,
{
if index >= interceptors.len() {
return Box::pin(async move { inner(request).await });
}
let interceptor = Arc::clone(&interceptors[index]);
let remaining = interceptors;
let next_index = index + 1;
Box::pin(async move {
let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
interceptor.intercept(
request,
Box::new(move |req| run_chain_from(remaining, next_index, req, inner)),
)
}));
match result {
Ok(fut) => fut.await,
Err(payload) => {
let msg = payload
.downcast_ref::<&str>()
.map(|s| (*s).to_owned())
.or_else(|| payload.downcast_ref::<String>().cloned())
.unwrap_or_else(|| "unknown panic".to_owned());
Err(McpAdapterError::InterceptorPanic { message: msg })
}
}
})
}
/// Stateless interceptor that emits `tracing` events before and after
/// forwarding each tool call down the chain.
#[derive(Default, Debug)]
pub struct LoggingInterceptor;
impl ToolCallInterceptor for LoggingInterceptor {
fn intercept<'a>(
&'a self,
request: McpToolCallRequest,
next: InterceptorNext<'a>,
) -> std::pin::Pin<
Box<
dyn std::future::Future<Output = Result<McpToolCallResult, McpAdapterError>>
+ Send
+ 'a,
>,
> {
Box::pin(async move {
tracing::debug!(
tool = %request.tool_name,
server = %request.server_name,
"MCP tool call intercepted"
);
let result = next(request).await;
match &result {
Ok(r) => tracing::debug!(is_error = r.is_error, "MCP tool call completed"),
Err(e) => tracing::warn!(error = %e, "MCP tool call failed in interceptor chain"),
}
result
})
}
}
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// Shared log of interceptor ids, appended to on entry and on exit.
    type OrderLog = Arc<tokio::sync::Mutex<Vec<char>>>;

    /// Records its `id` into `order` before and after forwarding to `next`, so
    /// the log exposes the nesting order of the chain.
    struct RecordingInterceptor {
        id: char,
        order: OrderLog,
    }

    impl ToolCallInterceptor for RecordingInterceptor {
        fn intercept<'a>(
            &'a self,
            request: McpToolCallRequest,
            next: InterceptorNext<'a>,
        ) -> std::pin::Pin<
            Box<
                dyn std::future::Future<Output = Result<McpToolCallResult, McpAdapterError>>
                    + Send
                    + 'a,
            >,
        > {
            Box::pin(async move {
                self.order.lock().await.push(self.id);
                let outcome = next(request).await;
                self.order.lock().await.push(self.id);
                outcome
            })
        }
    }

    /// Answers with a canned value and never calls `next`.
    struct ShortCircuitInterceptor;

    impl ToolCallInterceptor for ShortCircuitInterceptor {
        fn intercept<'a>(
            &'a self,
            _request: McpToolCallRequest,
            _next: InterceptorNext<'a>,
        ) -> std::pin::Pin<
            Box<
                dyn std::future::Future<Output = Result<McpToolCallResult, McpAdapterError>>
                    + Send
                    + 'a,
            >,
        > {
            let canned = McpToolCallResult {
                value: serde_json::json!({"short": "circuit"}),
                is_error: false,
            };
            Box::pin(async move { Ok(canned) })
        }
    }

    /// Inner call that always succeeds with `{"result": "ok"}`.
    fn make_inner() -> impl FnOnce(
        McpToolCallRequest,
    ) -> std::pin::Pin<
        Box<dyn std::future::Future<Output = Result<McpToolCallResult, McpAdapterError>> + Send>,
    > + Send
    + 'static {
        |_request| {
            Box::pin(async {
                let value = serde_json::json!({"result": "ok"});
                Ok(McpToolCallResult {
                    value,
                    is_error: false,
                })
            })
        }
    }

    fn make_request() -> McpToolCallRequest {
        McpToolCallRequest {
            tool_name: String::from("search"),
            server_name: String::from("s1"),
            arguments: serde_json::json!({}),
        }
    }

    fn new_order_log() -> OrderLog {
        Arc::new(tokio::sync::Mutex::new(Vec::new()))
    }

    /// Builds a `RecordingInterceptor` sharing `order`, already type-erased.
    fn recorder(id: char, order: &OrderLog) -> Arc<dyn ToolCallInterceptor> {
        Arc::new(RecordingInterceptor {
            id,
            order: Arc::clone(order),
        })
    }

    #[tokio::test]
    async fn onion_ordering_abc_to_cba() {
        let order = new_order_log();
        let interceptors: Vec<Arc<dyn ToolCallInterceptor>> =
            ['A', 'B', 'C'].iter().map(|&id| recorder(id, &order)).collect();
        let outcome = run_interceptor_chain(&interceptors, make_request(), make_inner()).await;
        assert!(outcome.is_ok());
        assert_eq!(*order.lock().await, vec!['A', 'B', 'C', 'C', 'B', 'A']);
    }

    #[tokio::test]
    async fn short_circuit_interceptor_stops_chain() {
        let order = new_order_log();
        let interceptors: Vec<Arc<dyn ToolCallInterceptor>> = vec![
            recorder('A', &order),
            Arc::new(ShortCircuitInterceptor),
            recorder('C', &order),
        ];
        let outcome = run_interceptor_chain(&interceptors, make_request(), make_inner()).await;
        assert!(outcome.is_ok());
        let value = outcome.unwrap().value;
        assert_eq!(value["short"], "circuit");
        // 'C' is never reached, while 'A' still records its exit.
        assert_eq!(*order.lock().await, vec!['A', 'A']);
    }

    #[tokio::test]
    async fn no_interceptors_calls_inner() {
        let outcome = run_interceptor_chain(&[], make_request(), make_inner()).await;
        assert!(outcome.is_ok());
        assert_eq!(outcome.unwrap().value["result"], "ok");
    }
}