Skip to main content

mcp_proxy/
validation.rs

1//! Request validation middleware for the proxy.
2//!
3//! Validates tool call arguments against size limits before forwarding.
4
5use std::convert::Infallible;
6use std::future::Future;
7use std::pin::Pin;
8use std::sync::Arc;
9use std::task::{Context, Poll};
10
11use tower::Service;
12
13use tower_mcp::protocol::McpRequest;
14use tower_mcp::{RouterRequest, RouterResponse};
15use tower_mcp_types::JsonRpcError;
16
/// Configuration for request validation.
///
/// The [`Default`] value leaves every limit unset, which disables the
/// corresponding checks.
#[derive(Debug, Clone, Default)]
pub struct ValidationConfig {
    /// Maximum serialized size of tool call arguments in bytes.
    ///
    /// `None` means no size limit is enforced.
    pub max_argument_size: Option<usize>,
}
23
24/// Middleware that validates requests before forwarding.
25#[derive(Clone)]
26pub struct ValidationService<S> {
27    inner: S,
28    config: Arc<ValidationConfig>,
29}
30
31impl<S> ValidationService<S> {
32    /// Create a new validation service wrapping `inner`.
33    pub fn new(inner: S, config: ValidationConfig) -> Self {
34        Self {
35            inner,
36            config: Arc::new(config),
37        }
38    }
39}
40
41impl<S> Service<RouterRequest> for ValidationService<S>
42where
43    S: Service<RouterRequest, Response = RouterResponse, Error = Infallible>
44        + Clone
45        + Send
46        + 'static,
47    S::Future: Send,
48{
49    type Response = RouterResponse;
50    type Error = Infallible;
51    type Future = Pin<Box<dyn Future<Output = Result<RouterResponse, Infallible>> + Send>>;
52
53    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
54        self.inner.poll_ready(cx)
55    }
56
57    fn call(&mut self, req: RouterRequest) -> Self::Future {
58        let config = Arc::clone(&self.config);
59        let request_id = req.id.clone();
60
61        // Validate argument size for tool calls
62        if let McpRequest::CallTool(ref params) = req.inner
63            && let Some(max_size) = config.max_argument_size
64        {
65            let size = serde_json::to_string(&params.arguments)
66                .map(|s| s.len())
67                .unwrap_or(0);
68            if size > max_size {
69                return Box::pin(async move {
70                    Ok(RouterResponse {
71                        id: request_id,
72                        inner: Err(JsonRpcError::invalid_params(format!(
73                            "Tool arguments exceed maximum size: {} bytes (limit: {} bytes)",
74                            size, max_size
75                        ))),
76                    })
77                });
78            }
79        }
80
81        let fut = self.inner.call(req);
82        Box::pin(fut)
83    }
84}
85
#[cfg(test)]
mod tests {
    use tower_mcp::protocol::{CallToolParams, McpRequest};

    use super::{ValidationConfig, ValidationService};
    use crate::test_util::{MockService, call_service};

    /// Build a `CallTool` request for the mock's `"tool"` tool with `arguments`.
    fn call_tool(arguments: serde_json::Value) -> McpRequest {
        McpRequest::CallTool(CallToolParams {
            name: "tool".to_string(),
            arguments,
            meta: None,
            task: None,
        })
    }

    #[tokio::test]
    async fn test_validation_passes_small_arguments() {
        let mock = MockService::with_tools(&["tool"]);
        let config = ValidationConfig {
            max_argument_size: Some(1024),
        };
        let mut svc = ValidationService::new(mock, config);

        let resp = call_service(&mut svc, call_tool(serde_json::json!({"key": "small"}))).await;

        assert!(resp.inner.is_ok(), "small args should pass validation");
    }

    #[tokio::test]
    async fn test_validation_rejects_large_arguments() {
        let mock = MockService::with_tools(&["tool"]);
        let config = ValidationConfig {
            max_argument_size: Some(10), // 10 bytes
        };
        let mut svc = ValidationService::new(mock, config);

        let resp = call_service(
            &mut svc,
            call_tool(serde_json::json!({"key": "this string is definitely longer than 10 bytes"})),
        )
        .await;

        let err = resp.inner.unwrap_err();
        assert!(
            err.message.contains("exceed maximum size"),
            "should mention size exceeded: {}",
            err.message
        );
    }

    #[tokio::test]
    async fn test_validation_accepts_arguments_at_exact_limit() {
        // `{"key":"small"}` serializes to exactly 15 bytes, and the limit is
        // inclusive: only sizes strictly greater than it are rejected.
        let mock = MockService::with_tools(&["tool"]);
        let config = ValidationConfig {
            max_argument_size: Some(15),
        };
        let mut svc = ValidationService::new(mock, config);

        let resp = call_service(&mut svc, call_tool(serde_json::json!({"key": "small"}))).await;

        assert!(resp.inner.is_ok(), "args exactly at the limit should pass");
    }

    #[tokio::test]
    async fn test_validation_rejects_arguments_one_byte_over_limit() {
        // Same 15-byte payload as the exact-limit test, with the limit lowered by one.
        let mock = MockService::with_tools(&["tool"]);
        let config = ValidationConfig {
            max_argument_size: Some(14),
        };
        let mut svc = ValidationService::new(mock, config);

        let resp = call_service(&mut svc, call_tool(serde_json::json!({"key": "small"}))).await;

        let err = resp.inner.unwrap_err();
        assert!(
            err.message.contains("15 bytes (limit: 14 bytes)"),
            "should report actual size and limit: {}",
            err.message
        );
    }

    #[tokio::test]
    async fn test_validation_passes_non_tool_requests() {
        let mock = MockService::with_tools(&["tool"]);
        let config = ValidationConfig {
            max_argument_size: Some(1),
        };
        let mut svc = ValidationService::new(mock, config);

        let resp = call_service(&mut svc, McpRequest::ListTools(Default::default())).await;
        assert!(resp.inner.is_ok(), "non-tool requests should pass");
    }

    #[tokio::test]
    async fn test_validation_disabled_passes_everything() {
        let mock = MockService::with_tools(&["tool"]);
        let config = ValidationConfig {
            max_argument_size: None,
        };
        let mut svc = ValidationService::new(mock, config);

        let resp = call_service(
            &mut svc,
            call_tool(serde_json::json!({"key": "any size is fine"})),
        )
        .await;

        assert!(resp.inner.is_ok());
    }
}
175}