use crate::error::{Error, Result};
use crate::server::cancellation::RequestHandlerExtra;
use async_trait::async_trait;
use serde_json::Value;
use std::collections::HashMap;
use std::sync::Arc;
/// Per-invocation information handed to every [`ToolMiddleware`] hook.
///
/// Typically built with [`ToolContext::new`] and refined with the `with_*`
/// builder methods.
#[derive(Clone, Debug)]
pub struct ToolContext {
    /// Name of the tool being invoked.
    pub tool_name: String,
    /// Session identifier, when one has been attached.
    pub session_id: Option<String>,
    /// Identifier of the request being processed.
    pub request_id: String,
    /// Free-form string key/value pairs attached to this invocation.
    pub metadata: HashMap<String, String>,
}
impl ToolContext {
    /// Creates a context for `tool_name` handling request `request_id`.
    ///
    /// The context starts with no session and no metadata.
    pub fn new(tool_name: impl Into<String>, request_id: impl Into<String>) -> Self {
        let tool_name = tool_name.into();
        let request_id = request_id.into();
        Self {
            tool_name,
            session_id: None,
            request_id,
            metadata: HashMap::default(),
        }
    }

    /// Builder-style setter that attaches `session_id`, replacing any previous one.
    pub fn with_session_id(self, session_id: impl Into<String>) -> Self {
        Self {
            session_id: Some(session_id.into()),
            ..self
        }
    }

    /// Builder-style variant of [`set_metadata`](Self::set_metadata).
    pub fn with_metadata(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
        self.set_metadata(key, value);
        self
    }

    /// Returns the metadata value stored under `key`, if there is one.
    pub fn get_metadata(&self, key: &str) -> Option<&String> {
        self.metadata.get(key)
    }

    /// Inserts `value` under `key`, overwriting any existing entry.
    pub fn set_metadata(&mut self, key: impl Into<String>, value: impl Into<String>) {
        let (key, value) = (key.into(), value.into());
        self.metadata.insert(key, value);
    }
}
/// Hooks that run around a tool invocation when registered in a
/// [`ToolMiddlewareChain`].
///
/// Every hook has a no-op default, so implementors override only what they
/// need. The chain stores implementations as `Arc<dyn ToolMiddleware>`, which
/// is why the trait requires `Send + Sync`.
#[async_trait]
pub trait ToolMiddleware: Send + Sync {
    /// Called before the tool runs. It may rewrite `args` or `extra` in place.
    ///
    /// Returning `Err` stops the chain's request phase. The default does nothing.
    async fn on_request(
        &self,
        tool_name: &str,
        args: &mut Value,
        extra: &mut RequestHandlerExtra,
        context: &ToolContext,
    ) -> Result<()> {
        // Default is a no-op; the discards only keep unused-argument lints quiet.
        let _ = tool_name;
        let _ = args;
        let _ = extra;
        let _ = context;
        Ok(())
    }

    /// Called with the tool's `result`, which may be rewritten in place.
    ///
    /// Returning `Err` stops the chain's response phase. The default does nothing.
    async fn on_response(
        &self,
        tool_name: &str,
        result: &mut Result<Value>,
        context: &ToolContext,
    ) -> Result<()> {
        let _ = tool_name;
        let _ = result;
        let _ = context;
        Ok(())
    }

    /// Called when the chain reports an error for `tool_name`.
    ///
    /// The chain logs an `Err` returned from this hook and otherwise ignores it.
    /// The default does nothing.
    async fn on_error(&self, tool_name: &str, error: &Error, context: &ToolContext) -> Result<()> {
        let _ = tool_name;
        let _ = error;
        let _ = context;
        Ok(())
    }

    /// Ordering key within the chain. Lower values run earlier in the request
    /// phase and later in the response phase. Defaults to `50`.
    fn priority(&self) -> i32 {
        50
    }

    /// Decides whether this middleware's request and response hooks run for
    /// `context`. Defaults to `true`.
    async fn should_execute(&self, _context: &ToolContext) -> bool {
        true
    }
}
/// An ordered collection of [`ToolMiddleware`] that runs around tool invocations.
///
/// Entries are kept sorted by ascending [`ToolMiddleware::priority`].
pub struct ToolMiddlewareChain {
    /// Registered middleware, sorted by ascending priority.
    middlewares: Vec<Arc<dyn ToolMiddleware>>,
}
impl ToolMiddlewareChain {
    /// Creates an empty chain.
    pub fn new() -> Self {
        Self {
            middlewares: Vec::new(),
        }
    }

    /// Registers `middleware`, keeping the chain sorted by ascending
    /// [`ToolMiddleware::priority`].
    ///
    /// Middleware with equal priority keep their registration order. The new
    /// entry is placed after every existing entry whose priority is less than
    /// or equal to its own.
    pub fn add(&mut self, middleware: Arc<dyn ToolMiddleware>) {
        // `add` is the only mutator, so `middlewares` is always sorted. A binary
        // search for the insertion point therefore gives the same order as
        // push-then-stable-sort, without re-sorting the whole chain on every
        // registration.
        // NOTE(review): assumes `priority()` is constant for each middleware; a
        // value that changes after registration does not trigger a re-sort.
        let priority = middleware.priority();
        let index = self
            .middlewares
            .partition_point(|existing| existing.priority() <= priority);
        self.middlewares.insert(index, middleware);
    }

    /// Runs the request phase.
    ///
    /// Calls [`ToolMiddleware::on_request`] on each middleware in ascending
    /// priority order. Middleware whose [`ToolMiddleware::should_execute`]
    /// returns `false` for `context` are skipped.
    ///
    /// # Errors
    ///
    /// On the first hook failure, the error is delivered to every registered
    /// middleware's `on_error` hook and then returned. Later middleware do not
    /// run.
    pub async fn process_request(
        &self,
        tool_name: &str,
        args: &mut Value,
        extra: &mut RequestHandlerExtra,
        context: &ToolContext,
    ) -> Result<()> {
        for middleware in &self.middlewares {
            if !middleware.should_execute(context).await {
                continue;
            }
            if let Err(e) = middleware.on_request(tool_name, args, extra, context).await {
                self.handle_error(tool_name, &e, context).await;
                return Err(e);
            }
        }
        Ok(())
    }

    /// Runs the response phase.
    ///
    /// Calls [`ToolMiddleware::on_response`] on each middleware in descending
    /// priority order (the reverse of the request phase). Middleware whose
    /// `should_execute` returns `false` for `context` are skipped.
    ///
    /// # Errors
    ///
    /// On the first hook failure, the error is delivered to every registered
    /// middleware's `on_error` hook and then returned. Later middleware do not
    /// run.
    pub async fn process_response(
        &self,
        tool_name: &str,
        result: &mut Result<Value>,
        context: &ToolContext,
    ) -> Result<()> {
        for middleware in self.middlewares.iter().rev() {
            if !middleware.should_execute(context).await {
                continue;
            }
            if let Err(e) = middleware.on_response(tool_name, result, context).await {
                self.handle_error(tool_name, &e, context).await;
                return Err(e);
            }
        }
        Ok(())
    }

    /// Delivers `error` to the `on_error` hook of every registered middleware,
    /// in ascending priority order.
    ///
    /// Unlike the request and response phases, this does not consult
    /// `should_execute`. A failing `on_error` hook is logged and does not stop
    /// the remaining hooks from running.
    async fn handle_error(&self, tool_name: &str, error: &Error, context: &ToolContext) {
        for middleware in &self.middlewares {
            if let Err(e) = middleware.on_error(tool_name, error, context).await {
                tracing::error!(
                    "Error in tool middleware on_error hook: {} (original error: {})",
                    e,
                    error
                );
            }
        }
    }

    /// Public wrapper that delivers `error` to every registered middleware's
    /// `on_error` hook.
    ///
    /// This is the same notification the chain sends internally when a request
    /// or response hook fails.
    pub async fn handle_tool_error(&self, tool_name: &str, error: &Error, context: &ToolContext) {
        self.handle_error(tool_name, error, context).await;
    }
}
impl Default for ToolMiddlewareChain {
fn default() -> Self {
Self::new()
}
}
impl std::fmt::Debug for ToolMiddlewareChain {
    /// Formats only the number of registered middleware, because
    /// `ToolMiddleware` trait objects are not required to implement `Debug`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let count = self.middlewares.len();
        let mut builder = f.debug_struct("ToolMiddlewareChain");
        builder.field("count", &count);
        builder.finish()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use tokio_util::sync::CancellationToken;

    /// Middleware that records its own execution into the JSON payload.
    ///
    /// In the request phase it sets `"<name>_executed"` and appends its name to
    /// `"order"`. In the response phase it appends its name to
    /// `"response_order"`. The arrays let tests assert on invocation order,
    /// which JSON object keys alone cannot express.
    struct TestMiddleware {
        priority: i32,
        name: String,
    }

    /// Appends `name` to the JSON array stored under `key`, creating the array
    /// if it is missing.
    fn record_invocation(map: &mut serde_json::Map<String, Value>, key: &str, name: &str) {
        let entry = map.entry(key).or_insert(Value::Array(Vec::new()));
        if let Value::Array(names) = entry {
            names.push(Value::String(name.to_string()));
        }
    }

    #[async_trait]
    impl ToolMiddleware for TestMiddleware {
        async fn on_request(
            &self,
            _tool_name: &str,
            args: &mut Value,
            _extra: &mut RequestHandlerExtra,
            _context: &ToolContext,
        ) -> Result<()> {
            if let Value::Object(map) = args {
                map.insert(format!("{}_executed", self.name), Value::Bool(true));
                record_invocation(map, "order", &self.name);
            }
            Ok(())
        }

        async fn on_response(
            &self,
            _tool_name: &str,
            result: &mut Result<Value>,
            _context: &ToolContext,
        ) -> Result<()> {
            if let Ok(Value::Object(map)) = result {
                record_invocation(map, "response_order", &self.name);
            }
            Ok(())
        }

        fn priority(&self) -> i32 {
            self.priority
        }
    }

    /// Builds a chain that registers `third` (100), `first` (10) and `second`
    /// (50) out of order, so the chain's priority sorting is exercised.
    fn out_of_order_chain() -> ToolMiddlewareChain {
        let mut chain = ToolMiddlewareChain::new();
        chain.add(Arc::new(TestMiddleware {
            priority: 100,
            name: "third".to_string(),
        }));
        chain.add(Arc::new(TestMiddleware {
            priority: 10,
            name: "first".to_string(),
        }));
        chain.add(Arc::new(TestMiddleware {
            priority: 50,
            name: "second".to_string(),
        }));
        chain
    }

    #[tokio::test]
    async fn test_middleware_chain_priority_ordering() {
        let chain = out_of_order_chain();
        let mut args = serde_json::json!({});
        let mut extra =
            RequestHandlerExtra::new("test-request".to_string(), CancellationToken::new());
        let context = ToolContext::new("test_tool", "req-123");
        chain
            .process_request("test_tool", &mut args, &mut extra, &context)
            .await
            .unwrap();
        let map = args.as_object().unwrap();
        assert!(map.contains_key("first_executed"));
        assert!(map.contains_key("second_executed"));
        assert!(map.contains_key("third_executed"));
        assert_eq!(
            args["order"],
            serde_json::json!(["first", "second", "third"]),
            "request hooks must run in ascending priority order"
        );
    }

    #[tokio::test]
    async fn test_middleware_chain_response_reverse_priority_ordering() {
        let chain = out_of_order_chain();
        let context = ToolContext::new("test_tool", "req-123");
        let mut result: Result<Value> = Ok(serde_json::json!({}));
        chain
            .process_response("test_tool", &mut result, &context)
            .await
            .unwrap();
        let value = result.unwrap();
        assert_eq!(
            value["response_order"],
            serde_json::json!(["third", "second", "first"]),
            "response hooks must run in descending priority order"
        );
    }

    #[tokio::test]
    async fn test_middleware_chain_equal_priority_keeps_registration_order() {
        let mut chain = ToolMiddlewareChain::new();
        for name in ["a", "b", "c"].iter() {
            chain.add(Arc::new(TestMiddleware {
                priority: 50,
                name: (*name).to_string(),
            }));
        }
        let mut args = serde_json::json!({});
        let mut extra =
            RequestHandlerExtra::new("test-request".to_string(), CancellationToken::new());
        let context = ToolContext::new("test_tool", "req-123");
        chain
            .process_request("test_tool", &mut args, &mut extra, &context)
            .await
            .unwrap();
        assert_eq!(args["order"], serde_json::json!(["a", "b", "c"]));
    }

    #[tokio::test]
    async fn test_middleware_chain_short_circuit_on_error() {
        struct FailingMiddleware;
        #[async_trait]
        impl ToolMiddleware for FailingMiddleware {
            async fn on_request(
                &self,
                _tool_name: &str,
                _args: &mut Value,
                _extra: &mut RequestHandlerExtra,
                _context: &ToolContext,
            ) -> Result<()> {
                Err(Error::protocol(
                    crate::ErrorCode::INVALID_PARAMS,
                    "Middleware failed",
                ))
            }
            fn priority(&self) -> i32 {
                50
            }
        }
        let mut chain = ToolMiddlewareChain::new();
        chain.add(Arc::new(FailingMiddleware));
        let mut args = serde_json::json!({});
        let mut extra =
            RequestHandlerExtra::new("test-request".to_string(), CancellationToken::new());
        let context = ToolContext::new("test_tool", "req-123");
        let result = chain
            .process_request("test_tool", &mut args, &mut extra, &context)
            .await;
        assert!(result.is_err());
        let error_string = result.unwrap_err().to_string();
        assert!(
            error_string.contains("Middleware failed"),
            "Expected error to contain 'Middleware failed', got: {}",
            error_string
        );
    }

    #[tokio::test]
    async fn test_on_error_notified_when_request_hook_fails() {
        use std::sync::atomic::{AtomicUsize, Ordering};
        struct ErrorCounter {
            errors: Arc<AtomicUsize>,
        }
        #[async_trait]
        impl ToolMiddleware for ErrorCounter {
            async fn on_error(
                &self,
                _tool_name: &str,
                _error: &Error,
                _context: &ToolContext,
            ) -> Result<()> {
                self.errors.fetch_add(1, Ordering::SeqCst);
                Ok(())
            }
        }
        struct Failing;
        #[async_trait]
        impl ToolMiddleware for Failing {
            async fn on_request(
                &self,
                _tool_name: &str,
                _args: &mut Value,
                _extra: &mut RequestHandlerExtra,
                _context: &ToolContext,
            ) -> Result<()> {
                Err(Error::protocol(
                    crate::ErrorCode::INVALID_PARAMS,
                    "Middleware failed",
                ))
            }
        }
        let errors = Arc::new(AtomicUsize::new(0));
        let mut chain = ToolMiddlewareChain::new();
        chain.add(Arc::new(ErrorCounter {
            errors: errors.clone(),
        }));
        chain.add(Arc::new(Failing));
        let mut args = serde_json::json!({});
        let mut extra =
            RequestHandlerExtra::new("test-request".to_string(), CancellationToken::new());
        let context = ToolContext::new("test_tool", "req-123");
        let result = chain
            .process_request("test_tool", &mut args, &mut extra, &context)
            .await;
        assert!(result.is_err());
        assert_eq!(errors.load(Ordering::SeqCst), 1);
    }

    #[test]
    fn test_tool_context() {
        let context = ToolContext::new("test_tool", "req-123")
            .with_session_id("session-456")
            .with_metadata("key1", "value1");
        assert_eq!(context.tool_name, "test_tool");
        assert_eq!(context.request_id, "req-123");
        assert_eq!(context.session_id, Some("session-456".to_string()));
        assert_eq!(context.get_metadata("key1"), Some(&"value1".to_string()));
    }

    #[tokio::test]
    async fn test_oauth_injection_flow() {
        use crate::server::auth::AuthContext;
        struct OAuthInjectionMiddleware;
        #[async_trait]
        impl ToolMiddleware for OAuthInjectionMiddleware {
            async fn on_request(
                &self,
                _tool_name: &str,
                _args: &mut Value,
                extra: &mut RequestHandlerExtra,
                _context: &ToolContext,
            ) -> Result<()> {
                if let Some(auth_ctx) = &extra.auth_context {
                    if let Some(token) = &auth_ctx.token {
                        extra.set_metadata("oauth_token".to_string(), token.clone());
                    }
                }
                Ok(())
            }
            fn priority(&self) -> i32 {
                10
            }
        }
        let mut chain = ToolMiddlewareChain::new();
        chain.add(Arc::new(OAuthInjectionMiddleware));
        let mut extra = RequestHandlerExtra::new("test-req".to_string(), CancellationToken::new())
            .with_auth_context(Some(AuthContext {
                subject: "user-123".to_string(),
                scopes: vec!["read".to_string(), "write".to_string()],
                claims: std::collections::HashMap::new(),
                token: Some("oauth-token-abc123".to_string()),
                client_id: Some("client-456".to_string()),
                expires_at: None,
                authenticated: true,
            }));
        let mut args = serde_json::json!({});
        let context = ToolContext::new("test_tool", "req-123");
        chain
            .process_request("test_tool", &mut args, &mut extra, &context)
            .await
            .unwrap();
        assert_eq!(
            extra.get_metadata("oauth_token"),
            Some(&"oauth-token-abc123".to_string())
        );
    }

    #[tokio::test]
    async fn test_conditional_execution_by_context() {
        struct SessionAwareMiddleware {
            executed: Arc<parking_lot::Mutex<bool>>,
        }
        #[async_trait]
        impl ToolMiddleware for SessionAwareMiddleware {
            async fn on_request(
                &self,
                _tool_name: &str,
                _args: &mut Value,
                _extra: &mut RequestHandlerExtra,
                _context: &ToolContext,
            ) -> Result<()> {
                *self.executed.lock() = true;
                Ok(())
            }
            async fn should_execute(&self, context: &ToolContext) -> bool {
                context.session_id.is_some() && context.tool_name.starts_with("api_")
            }
        }
        let executed = Arc::new(parking_lot::Mutex::new(false));
        let middleware = SessionAwareMiddleware {
            executed: executed.clone(),
        };
        let mut chain = ToolMiddlewareChain::new();
        chain.add(Arc::new(middleware));
        let mut extra = RequestHandlerExtra::new("test-req".to_string(), CancellationToken::new());
        let mut args = serde_json::json!({});

        let context = ToolContext::new("api_call", "req-123");
        chain
            .process_request("api_call", &mut args, &mut extra, &context)
            .await
            .unwrap();
        assert!(!*executed.lock(), "Should not execute without session_id");

        *executed.lock() = false;
        let context = ToolContext::new("db_query", "req-124").with_session_id("session-456");
        chain
            .process_request("db_query", &mut args, &mut extra, &context)
            .await
            .unwrap();
        assert!(
            !*executed.lock(),
            "Should not execute for non-api_ tool names"
        );

        *executed.lock() = false;
        let context = ToolContext::new("api_call", "req-125").with_session_id("session-456");
        chain
            .process_request("api_call", &mut args, &mut extra, &context)
            .await
            .unwrap();
        assert!(
            *executed.lock(),
            "Should execute with session_id and api_ prefix"
        );
    }

    #[tokio::test]
    async fn test_concurrent_middleware_execution() {
        use std::sync::atomic::{AtomicUsize, Ordering};
        struct CountingMiddleware {
            counter: Arc<AtomicUsize>,
        }
        #[async_trait]
        impl ToolMiddleware for CountingMiddleware {
            async fn on_request(
                &self,
                _tool_name: &str,
                _args: &mut Value,
                extra: &mut RequestHandlerExtra,
                _context: &ToolContext,
            ) -> Result<()> {
                let count = self.counter.fetch_add(1, Ordering::SeqCst);
                extra.set_metadata("execution_number".to_string(), count.to_string());
                Ok(())
            }
        }
        let counter = Arc::new(AtomicUsize::new(0));
        let middleware = CountingMiddleware {
            counter: counter.clone(),
        };
        let mut chain_builder = ToolMiddlewareChain::new();
        chain_builder.add(Arc::new(middleware));
        let chain = Arc::new(chain_builder);
        let mut handles = Vec::new();
        for i in 0..100 {
            let chain = chain.clone();
            let handle = tokio::spawn(async move {
                let mut extra =
                    RequestHandlerExtra::new(format!("req-{}", i), CancellationToken::new());
                let mut args = serde_json::json!({});
                let context = ToolContext::new("test_tool", format!("req-{}", i));
                chain
                    .process_request("test_tool", &mut args, &mut extra, &context)
                    .await
                    .unwrap();
                assert!(extra.get_metadata("execution_number").is_some());
            });
            handles.push(handle);
        }
        for handle in handles {
            handle.await.unwrap();
        }
        assert_eq!(counter.load(Ordering::SeqCst), 100);
    }

    #[tokio::test]
    async fn test_no_leak_in_logging() {
        let mut extra = RequestHandlerExtra::new("test-req".to_string(), CancellationToken::new());
        extra.set_metadata(
            "oauth_token".to_string(),
            "super-secret-token-xyz".to_string(),
        );
        extra.set_metadata("user_id".to_string(), "user-456".to_string());
        let debug_output = format!("{:?}", extra);
        assert!(
            debug_output.contains("[REDACTED]"),
            "Expected [REDACTED] in debug output"
        );
        assert!(
            !debug_output.contains("super-secret-token-xyz"),
            "Token should not appear in debug output: {}",
            debug_output
        );
        assert!(
            debug_output.contains("user-456"),
            "Non-sensitive metadata should be visible: {}",
            debug_output
        );
    }
}