1use std::convert::Infallible;
49use std::future::Future;
50use std::pin::Pin;
51use std::sync::Arc;
52use std::task::{Context, Poll};
53
54use tower::{Layer, Service};
55
56use tower_mcp::protocol::McpRequest;
57
58#[derive(Clone)]
60pub struct ValidationLayer {
61 config: ValidationConfig,
62}
63
64impl ValidationLayer {
65 pub fn new(config: ValidationConfig) -> Self {
67 Self { config }
68 }
69}
70
71impl<S> Layer<S> for ValidationLayer {
72 type Service = ValidationService<S>;
73
74 fn layer(&self, inner: S) -> Self::Service {
75 ValidationService::new(inner, self.config.clone())
76 }
77}
78use tower_mcp::{RouterRequest, RouterResponse};
79use tower_mcp_types::JsonRpcError;
80
/// Limits enforced by `ValidationService`.
///
/// The `Default` value (`max_argument_size: None`) disables every check, so
/// all requests pass through unchanged.
#[derive(Clone, Debug, Default)]
pub struct ValidationConfig {
    /// Maximum size, in bytes, of a `McpRequest::CallTool` request's
    /// `arguments` when serialized as compact JSON.
    ///
    /// A payload whose serialized length equals the limit is accepted; one
    /// byte more is rejected with a JSON-RPC `invalid_params` error. `None`
    /// disables the check.
    pub max_argument_size: Option<usize>,
}
87
88#[derive(Clone)]
90pub struct ValidationService<S> {
91 inner: S,
92 config: Arc<ValidationConfig>,
93}
94
95impl<S> ValidationService<S> {
96 pub fn new(inner: S, config: ValidationConfig) -> Self {
98 Self {
99 inner,
100 config: Arc::new(config),
101 }
102 }
103}
104
105impl<S> Service<RouterRequest> for ValidationService<S>
106where
107 S: Service<RouterRequest, Response = RouterResponse, Error = Infallible>
108 + Clone
109 + Send
110 + 'static,
111 S::Future: Send,
112{
113 type Response = RouterResponse;
114 type Error = Infallible;
115 type Future = Pin<Box<dyn Future<Output = Result<RouterResponse, Infallible>> + Send>>;
116
117 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
118 self.inner.poll_ready(cx)
119 }
120
121 fn call(&mut self, req: RouterRequest) -> Self::Future {
122 let config = Arc::clone(&self.config);
123 let request_id = req.id.clone();
124
125 if let McpRequest::CallTool(ref params) = req.inner
127 && let Some(max_size) = config.max_argument_size
128 {
129 let size = serde_json::to_string(¶ms.arguments)
130 .map(|s| s.len())
131 .unwrap_or(0);
132 if size > max_size {
133 return Box::pin(async move {
134 Ok(RouterResponse {
135 id: request_id,
136 inner: Err(JsonRpcError::invalid_params(format!(
137 "Tool arguments exceed maximum size: {} bytes (limit: {} bytes)",
138 size, max_size
139 ))),
140 })
141 });
142 }
143 }
144
145 let fut = self.inner.call(req);
146 Box::pin(fut)
147 }
148}
149
#[cfg(test)]
mod tests {
    use tower_mcp::protocol::{CallToolParams, McpRequest};

    use super::{ValidationConfig, ValidationService};
    use crate::test_util::{MockService, call_service};

    /// Builds a `CallTool` request targeting the mock's `tool` with `arguments`.
    fn tool_call(arguments: serde_json::Value) -> McpRequest {
        McpRequest::CallTool(CallToolParams {
            name: "tool".to_string(),
            arguments,
            meta: None,
            task: None,
        })
    }

    #[tokio::test]
    async fn test_validation_passes_small_arguments() {
        let mock = MockService::with_tools(&["tool"]);
        let config = ValidationConfig {
            max_argument_size: Some(1024),
        };
        let mut svc = ValidationService::new(mock, config);

        let resp = call_service(&mut svc, tool_call(serde_json::json!({"key": "small"}))).await;

        assert!(resp.inner.is_ok(), "small args should pass validation");
    }

    #[tokio::test]
    async fn test_validation_rejects_large_arguments() {
        let mock = MockService::with_tools(&["tool"]);
        let config = ValidationConfig {
            max_argument_size: Some(10),
        };
        let mut svc = ValidationService::new(mock, config);

        let resp = call_service(
            &mut svc,
            tool_call(serde_json::json!({"key": "this string is definitely longer than 10 bytes"})),
        )
        .await;

        let err = resp.inner.unwrap_err();
        assert!(
            err.message.contains("exceed maximum size"),
            "should mention size exceeded: {}",
            err.message
        );
    }

    /// Pins the boundary: a payload exactly at the limit passes, one byte over fails.
    #[tokio::test]
    async fn test_validation_limit_is_inclusive() {
        let arguments = serde_json::json!({"key": "small"});
        let limit = serde_json::to_string(&arguments).unwrap().len();

        let mut at_limit = ValidationService::new(
            MockService::with_tools(&["tool"]),
            ValidationConfig {
                max_argument_size: Some(limit),
            },
        );
        let resp = call_service(&mut at_limit, tool_call(arguments.clone())).await;
        assert!(resp.inner.is_ok(), "args exactly at the limit should pass");

        let mut below_limit = ValidationService::new(
            MockService::with_tools(&["tool"]),
            ValidationConfig {
                max_argument_size: Some(limit - 1),
            },
        );
        let resp = call_service(&mut below_limit, tool_call(arguments)).await;
        assert!(resp.inner.is_err(), "args one byte over the limit should be rejected");
    }

    #[tokio::test]
    async fn test_validation_passes_non_tool_requests() {
        let mock = MockService::with_tools(&["tool"]);
        let config = ValidationConfig {
            max_argument_size: Some(1),
        };
        let mut svc = ValidationService::new(mock, config);

        let resp = call_service(&mut svc, McpRequest::ListTools(Default::default())).await;
        assert!(resp.inner.is_ok(), "non-tool requests should pass");
    }

    #[tokio::test]
    async fn test_validation_disabled_passes_everything() {
        let mock = MockService::with_tools(&["tool"]);
        let config = ValidationConfig {
            max_argument_size: None,
        };
        let mut svc = ValidationService::new(mock, config);

        let resp = call_service(
            &mut svc,
            tool_call(serde_json::json!({"key": "any size is fine"})),
        )
        .await;

        assert!(resp.inner.is_ok());
    }
}