Skip to main content

mcp_proxy/
validation.rs

1//! Request validation middleware for the proxy.
2//!
3//! Validates incoming requests against configurable constraints before they
4//! reach backend services. Currently supports argument size limits for tool
5//! calls; additional validation rules can be added to [`ValidationConfig`]
6//! as needed.
7//!
8//! # Argument size validation
9//!
10//! When `max_argument_size` is set, the [`ValidationService`] serializes
11//! `CallTool` arguments to JSON and checks the byte length against the
12//! limit. Requests that exceed the limit are rejected immediately with an
13//! `invalid_params` JSON-RPC error containing the actual and maximum sizes.
14//! This prevents oversized payloads from reaching backends that may have
15//! their own (less informative) size limits.
16//!
17//! Non-`CallTool` requests (e.g., `ListTools`, `ReadResource`, `Ping`)
18//! pass through without validation. When `max_argument_size` is `None`,
19//! all requests pass through.
20//!
21//! # Configuration
22//!
23//! Argument size limits are configured in the `[security]` section of the
24//! TOML config:
25//!
26//! ```toml
27//! [security]
28//! max_argument_size = 1048576  # 1 MiB
29//! ```
30//!
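//! Omit `max_argument_size` to disable argument size validation entirely.
//! TOML has no `null` value, so omitting the key is the only way to disable
//! it there; in a YAML config, setting it to `null` has the same effect.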
33//!
34//! # Middleware stack position
35//!
36//! Validation runs early in the middleware stack -- after request coalescing
37//! but before capability filtering. This means oversized requests are
38//! rejected before any filtering or routing logic runs. The ordering in
39//! `proxy.rs`:
40//!
41//! 1. Request coalescing
42//! 2. **Request validation** (this module)
43//! 3. Capability filtering ([`crate::filter`])
44//! 4. Search-mode filtering ([`crate::filter`])
45//! 5. Tool aliasing ([`crate::alias`])
46//! 6. Composite tools ([`crate::composite`])
47
48use std::convert::Infallible;
49use std::future::Future;
50use std::pin::Pin;
51use std::sync::Arc;
52use std::task::{Context, Poll};
53
54use tower::{Layer, Service};
55
56use tower_mcp::protocol::McpRequest;
57
58/// Tower layer that produces a [`ValidationService`].
59#[derive(Clone)]
60pub struct ValidationLayer {
61    config: ValidationConfig,
62}
63
64impl ValidationLayer {
65    /// Create a new validation layer.
66    pub fn new(config: ValidationConfig) -> Self {
67        Self { config }
68    }
69}
70
71impl<S> Layer<S> for ValidationLayer {
72    type Service = ValidationService<S>;
73
74    fn layer(&self, inner: S) -> Self::Service {
75        ValidationService::new(inner, self.config.clone())
76    }
77}
78use tower_mcp::{RouterRequest, RouterResponse};
79use tower_mcp_types::JsonRpcError;
80
/// Configuration for request validation.
///
/// The [`Default`] configuration disables every check (`max_argument_size`
/// is `None`), so all requests pass through unchanged.
#[derive(Clone, Debug, Default)]
pub struct ValidationConfig {
    /// Maximum serialized size of tool call arguments in bytes.
    ///
    /// The size is the byte length of the compact JSON encoding of the
    /// `CallTool` arguments. `None` disables argument size validation.
    pub max_argument_size: Option<usize>,
}
87
88/// Middleware that validates requests before forwarding.
89#[derive(Clone)]
90pub struct ValidationService<S> {
91    inner: S,
92    config: Arc<ValidationConfig>,
93}
94
95impl<S> ValidationService<S> {
96    /// Create a new validation service wrapping `inner`.
97    pub fn new(inner: S, config: ValidationConfig) -> Self {
98        Self {
99            inner,
100            config: Arc::new(config),
101        }
102    }
103}
104
105impl<S> Service<RouterRequest> for ValidationService<S>
106where
107    S: Service<RouterRequest, Response = RouterResponse, Error = Infallible>
108        + Clone
109        + Send
110        + 'static,
111    S::Future: Send,
112{
113    type Response = RouterResponse;
114    type Error = Infallible;
115    type Future = Pin<Box<dyn Future<Output = Result<RouterResponse, Infallible>> + Send>>;
116
117    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
118        self.inner.poll_ready(cx)
119    }
120
121    fn call(&mut self, req: RouterRequest) -> Self::Future {
122        let config = Arc::clone(&self.config);
123        let request_id = req.id.clone();
124
125        // Validate argument size for tool calls
126        if let McpRequest::CallTool(ref params) = req.inner
127            && let Some(max_size) = config.max_argument_size
128        {
129            let size = serde_json::to_string(&params.arguments)
130                .map(|s| s.len())
131                .unwrap_or(0);
132            if size > max_size {
133                return Box::pin(async move {
134                    Ok(RouterResponse {
135                        id: request_id,
136                        inner: Err(JsonRpcError::invalid_params(format!(
137                            "Tool arguments exceed maximum size: {} bytes (limit: {} bytes)",
138                            size, max_size
139                        ))),
140                    })
141                });
142            }
143        }
144
145        let fut = self.inner.call(req);
146        Box::pin(fut)
147    }
148}
149
#[cfg(test)]
mod tests {
    use tower_mcp::protocol::{CallToolParams, McpRequest};

    use super::{ValidationConfig, ValidationService};
    use crate::test_util::{MockService, call_service};

    /// Build a `CallTool` request targeting the mock's `"tool"` with `arguments`.
    fn call_tool(arguments: serde_json::Value) -> McpRequest {
        McpRequest::CallTool(CallToolParams {
            name: "tool".to_string(),
            arguments,
            meta: None,
            task: None,
        })
    }

    #[tokio::test]
    async fn test_validation_passes_small_arguments() {
        let mock = MockService::with_tools(&["tool"]);
        let config = ValidationConfig {
            max_argument_size: Some(1024),
        };
        let mut svc = ValidationService::new(mock, config);

        let resp = call_service(&mut svc, call_tool(serde_json::json!({"key": "small"}))).await;

        assert!(resp.inner.is_ok(), "small args should pass validation");
    }

    #[tokio::test]
    async fn test_validation_rejects_large_arguments() {
        let mock = MockService::with_tools(&["tool"]);
        let config = ValidationConfig {
            max_argument_size: Some(10), // 10 bytes
        };
        let mut svc = ValidationService::new(mock, config);

        let resp = call_service(
            &mut svc,
            call_tool(serde_json::json!({"key": "this string is definitely longer than 10 bytes"})),
        )
        .await;

        let err = resp.inner.unwrap_err();
        assert!(
            err.message.contains("exceed maximum size"),
            "should mention size exceeded: {}",
            err.message
        );
    }

    #[tokio::test]
    async fn test_validation_passes_arguments_exactly_at_limit() {
        // The check is `size > limit`, so a payload of exactly `limit` bytes passes.
        let args = serde_json::json!({"key": "value"});
        let exact = serde_json::to_string(&args).unwrap().len();

        let mock = MockService::with_tools(&["tool"]);
        let config = ValidationConfig {
            max_argument_size: Some(exact),
        };
        let mut svc = ValidationService::new(mock, config);

        let resp = call_service(&mut svc, call_tool(args)).await;
        assert!(
            resp.inner.is_ok(),
            "arguments exactly at the limit should pass validation"
        );
    }

    #[tokio::test]
    async fn test_validation_reports_actual_and_limit_sizes() {
        let args = serde_json::json!({"key": "value"});
        let exact = serde_json::to_string(&args).unwrap().len();
        let limit = exact - 1;

        let mock = MockService::with_tools(&["tool"]);
        let config = ValidationConfig {
            max_argument_size: Some(limit),
        };
        let mut svc = ValidationService::new(mock, config);

        let resp = call_service(&mut svc, call_tool(args)).await;
        let err = resp.inner.unwrap_err();
        let expected = format!("{} bytes (limit: {} bytes)", exact, limit);
        assert!(
            err.message.contains(&expected),
            "error should report actual and limit sizes ({}): {}",
            expected,
            err.message
        );
    }

    #[tokio::test]
    async fn test_validation_passes_non_tool_requests() {
        let mock = MockService::with_tools(&["tool"]);
        let config = ValidationConfig {
            max_argument_size: Some(1),
        };
        let mut svc = ValidationService::new(mock, config);

        let resp = call_service(&mut svc, McpRequest::ListTools(Default::default())).await;
        assert!(resp.inner.is_ok(), "non-tool requests should pass");
    }

    #[tokio::test]
    async fn test_validation_disabled_passes_everything() {
        let mock = MockService::with_tools(&["tool"]);
        let config = ValidationConfig {
            max_argument_size: None,
        };
        let mut svc = ValidationService::new(mock, config);

        let resp = call_service(
            &mut svc,
            call_tool(serde_json::json!({"key": "any size is fine"})),
        )
        .await;

        assert!(resp.inner.is_ok());
    }
}