Skip to main content

mcp_proxy/
validation.rs

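//! Request validation middleware for the proxy.
//!
//! Checks the serialized size of tool call arguments against a configurable
//! limit before forwarding. Oversized calls are answered with a JSON-RPC
//! `invalid_params` error instead of being passed to the inner service.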
4
5use std::convert::Infallible;
6use std::future::Future;
7use std::pin::Pin;
8use std::sync::Arc;
9use std::task::{Context, Poll};
10
11use tower::{Layer, Service};
12
13use tower_mcp::protocol::McpRequest;
14
15/// Tower layer that produces a [`ValidationService`].
16#[derive(Clone)]
17pub struct ValidationLayer {
18    config: ValidationConfig,
19}
20
21impl ValidationLayer {
22    /// Create a new validation layer.
23    pub fn new(config: ValidationConfig) -> Self {
24        Self { config }
25    }
26}
27
28impl<S> Layer<S> for ValidationLayer {
29    type Service = ValidationService<S>;
30
31    fn layer(&self, inner: S) -> Self::Service {
32        ValidationService::new(inner, self.config.clone())
33    }
34}
35use tower_mcp::{RouterRequest, RouterResponse};
36use tower_mcp_types::JsonRpcError;
37
/// Configuration for request validation.
///
/// The [`Default`] configuration sets no limits, so no request is rejected.
#[derive(Clone, Debug, Default)]
pub struct ValidationConfig {
    /// Maximum serialized size of tool call arguments in bytes.
    ///
    /// The size is the length of the compact JSON encoding of the arguments.
    /// Arguments whose size equals the limit are accepted; only strictly larger
    /// payloads are rejected. `None` disables the check.
    pub max_argument_size: Option<usize>,
}
44
45/// Middleware that validates requests before forwarding.
46#[derive(Clone)]
47pub struct ValidationService<S> {
48    inner: S,
49    config: Arc<ValidationConfig>,
50}
51
52impl<S> ValidationService<S> {
53    /// Create a new validation service wrapping `inner`.
54    pub fn new(inner: S, config: ValidationConfig) -> Self {
55        Self {
56            inner,
57            config: Arc::new(config),
58        }
59    }
60}
61
62impl<S> Service<RouterRequest> for ValidationService<S>
63where
64    S: Service<RouterRequest, Response = RouterResponse, Error = Infallible>
65        + Clone
66        + Send
67        + 'static,
68    S::Future: Send,
69{
70    type Response = RouterResponse;
71    type Error = Infallible;
72    type Future = Pin<Box<dyn Future<Output = Result<RouterResponse, Infallible>> + Send>>;
73
74    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
75        self.inner.poll_ready(cx)
76    }
77
78    fn call(&mut self, req: RouterRequest) -> Self::Future {
79        let config = Arc::clone(&self.config);
80        let request_id = req.id.clone();
81
82        // Validate argument size for tool calls
83        if let McpRequest::CallTool(ref params) = req.inner
84            && let Some(max_size) = config.max_argument_size
85        {
86            let size = serde_json::to_string(&params.arguments)
87                .map(|s| s.len())
88                .unwrap_or(0);
89            if size > max_size {
90                return Box::pin(async move {
91                    Ok(RouterResponse {
92                        id: request_id,
93                        inner: Err(JsonRpcError::invalid_params(format!(
94                            "Tool arguments exceed maximum size: {} bytes (limit: {} bytes)",
95                            size, max_size
96                        ))),
97                    })
98                });
99            }
100        }
101
102        let fut = self.inner.call(req);
103        Box::pin(fut)
104    }
105}
106
#[cfg(test)]
mod tests {
    use tower_mcp::protocol::{CallToolParams, McpRequest};

    use super::{ValidationConfig, ValidationService};
    use crate::test_util::{MockService, call_service};

    /// Builds a `CallTool` request for the mock's `tool` with the given arguments.
    fn call_tool(arguments: serde_json::Value) -> McpRequest {
        McpRequest::CallTool(CallToolParams {
            name: "tool".to_string(),
            arguments,
            meta: None,
            task: None,
        })
    }

    /// Length of the compact JSON encoding the middleware measures.
    fn encoded_len(arguments: &serde_json::Value) -> usize {
        serde_json::to_string(arguments)
            .expect("serde_json::Value always serializes")
            .len()
    }

    #[tokio::test]
    async fn test_validation_passes_small_arguments() {
        let mock = MockService::with_tools(&["tool"]);
        let config = ValidationConfig {
            max_argument_size: Some(1024),
        };
        let mut svc = ValidationService::new(mock, config);

        let resp = call_service(&mut svc, call_tool(serde_json::json!({"key": "small"}))).await;

        assert!(resp.inner.is_ok(), "small args should pass validation");
    }

    #[tokio::test]
    async fn test_validation_rejects_large_arguments() {
        let mock = MockService::with_tools(&["tool"]);
        let config = ValidationConfig {
            max_argument_size: Some(10), // 10 bytes
        };
        let mut svc = ValidationService::new(mock, config);

        let resp = call_service(
            &mut svc,
            call_tool(serde_json::json!({"key": "this string is definitely longer than 10 bytes"})),
        )
        .await;

        let err = resp.inner.unwrap_err();
        assert!(
            err.message.contains("exceed maximum size"),
            "should mention size exceeded: {}",
            err.message
        );
    }

    #[tokio::test]
    async fn test_validation_allows_arguments_exactly_at_limit() {
        let arguments = serde_json::json!({"k": "v"});
        let limit = encoded_len(&arguments);
        let mock = MockService::with_tools(&["tool"]);
        let config = ValidationConfig {
            max_argument_size: Some(limit),
        };
        let mut svc = ValidationService::new(mock, config);

        let resp = call_service(&mut svc, call_tool(arguments)).await;

        assert!(
            resp.inner.is_ok(),
            "arguments exactly at the limit should pass"
        );
    }

    #[tokio::test]
    async fn test_validation_rejects_one_byte_over_limit_with_exact_sizes() {
        let arguments = serde_json::json!({"k": "v"});
        let size = encoded_len(&arguments);
        let limit = size - 1;
        let mock = MockService::with_tools(&["tool"]);
        let config = ValidationConfig {
            max_argument_size: Some(limit),
        };
        let mut svc = ValidationService::new(mock, config);

        let resp = call_service(&mut svc, call_tool(arguments)).await;

        let err = resp.inner.unwrap_err();
        let expected = format!("{size} bytes (limit: {limit} bytes)");
        assert!(
            err.message.contains(&expected),
            "should report exact size and limit ({expected}): {}",
            err.message
        );
    }

    #[tokio::test]
    async fn test_validation_passes_non_tool_requests() {
        let mock = MockService::with_tools(&["tool"]);
        let config = ValidationConfig {
            max_argument_size: Some(1),
        };
        let mut svc = ValidationService::new(mock, config);

        let resp = call_service(&mut svc, McpRequest::ListTools(Default::default())).await;
        assert!(resp.inner.is_ok(), "non-tool requests should pass");
    }

    #[tokio::test]
    async fn test_validation_disabled_passes_everything() {
        let mock = MockService::with_tools(&["tool"]);
        let config = ValidationConfig {
            max_argument_size: None,
        };
        let mut svc = ValidationService::new(mock, config);

        let resp = call_service(
            &mut svc,
            call_tool(serde_json::json!({"key": "any size is fine"})),
        )
        .await;

        assert!(resp.inner.is_ok());
    }
}