1use std::convert::Infallible;
6use std::future::Future;
7use std::pin::Pin;
8use std::sync::Arc;
9use std::task::{Context, Poll};
10
11use tower::{Layer, Service};
12
13use tower_mcp::protocol::McpRequest;
14
15#[derive(Clone)]
17pub struct ValidationLayer {
18 config: ValidationConfig,
19}
20
21impl ValidationLayer {
22 pub fn new(config: ValidationConfig) -> Self {
24 Self { config }
25 }
26}
27
28impl<S> Layer<S> for ValidationLayer {
29 type Service = ValidationService<S>;
30
31 fn layer(&self, inner: S) -> Self::Service {
32 ValidationService::new(inner, self.config.clone())
33 }
34}
35use tower_mcp::{RouterRequest, RouterResponse};
36use tower_mcp_types::JsonRpcError;
37
/// Limits enforced by [`ValidationService`].
///
/// The [`Default`] configuration disables every check
/// (`max_argument_size: None`).
#[derive(Clone, Debug, Default, PartialEq, Eq)]
pub struct ValidationConfig {
    /// Maximum size, in bytes, of a tool call's arguments when serialized as
    /// compact JSON. A call whose arguments are strictly larger than this is
    /// rejected; `None` disables the check.
    pub max_argument_size: Option<usize>,
}
44
45#[derive(Clone)]
47pub struct ValidationService<S> {
48 inner: S,
49 config: Arc<ValidationConfig>,
50}
51
52impl<S> ValidationService<S> {
53 pub fn new(inner: S, config: ValidationConfig) -> Self {
55 Self {
56 inner,
57 config: Arc::new(config),
58 }
59 }
60}
61
62impl<S> Service<RouterRequest> for ValidationService<S>
63where
64 S: Service<RouterRequest, Response = RouterResponse, Error = Infallible>
65 + Clone
66 + Send
67 + 'static,
68 S::Future: Send,
69{
70 type Response = RouterResponse;
71 type Error = Infallible;
72 type Future = Pin<Box<dyn Future<Output = Result<RouterResponse, Infallible>> + Send>>;
73
74 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
75 self.inner.poll_ready(cx)
76 }
77
78 fn call(&mut self, req: RouterRequest) -> Self::Future {
79 let config = Arc::clone(&self.config);
80 let request_id = req.id.clone();
81
82 if let McpRequest::CallTool(ref params) = req.inner
84 && let Some(max_size) = config.max_argument_size
85 {
86 let size = serde_json::to_string(¶ms.arguments)
87 .map(|s| s.len())
88 .unwrap_or(0);
89 if size > max_size {
90 return Box::pin(async move {
91 Ok(RouterResponse {
92 id: request_id,
93 inner: Err(JsonRpcError::invalid_params(format!(
94 "Tool arguments exceed maximum size: {} bytes (limit: {} bytes)",
95 size, max_size
96 ))),
97 })
98 });
99 }
100 }
101
102 let fut = self.inner.call(req);
103 Box::pin(fut)
104 }
105}
106
#[cfg(test)]
mod tests {
    use tower_mcp::protocol::{CallToolParams, McpRequest};

    use super::{ValidationConfig, ValidationService};
    use crate::test_util::{MockService, call_service};

    /// Builds a `CallTool` request for the mock's `tool` with the given arguments.
    fn call_tool(arguments: serde_json::Value) -> McpRequest {
        McpRequest::CallTool(CallToolParams {
            name: "tool".to_string(),
            arguments,
            meta: None,
            task: None,
        })
    }

    #[tokio::test]
    async fn test_validation_passes_small_arguments() {
        let mock = MockService::with_tools(&["tool"]);
        let config = ValidationConfig {
            max_argument_size: Some(1024),
        };
        let mut svc = ValidationService::new(mock, config);

        let resp = call_service(&mut svc, call_tool(serde_json::json!({"key": "small"}))).await;

        assert!(resp.inner.is_ok(), "small args should pass validation");
    }

    #[tokio::test]
    async fn test_validation_rejects_large_arguments() {
        let mock = MockService::with_tools(&["tool"]);
        let config = ValidationConfig {
            max_argument_size: Some(10),
        };
        let mut svc = ValidationService::new(mock, config);

        let resp = call_service(
            &mut svc,
            call_tool(serde_json::json!({"key": "this string is definitely longer than 10 bytes"})),
        )
        .await;

        let err = resp.inner.unwrap_err();
        assert!(
            err.message.contains("exceed maximum size"),
            "should mention size exceeded: {}",
            err.message
        );
    }

    #[tokio::test]
    async fn test_validation_limit_is_inclusive() {
        // `{"k":"v"}` serializes to exactly 9 bytes of compact JSON.
        let args = serde_json::json!({"k": "v"});

        let mut at_limit = ValidationService::new(
            MockService::with_tools(&["tool"]),
            ValidationConfig {
                max_argument_size: Some(9),
            },
        );
        let resp = call_service(&mut at_limit, call_tool(args.clone())).await;
        assert!(resp.inner.is_ok(), "arguments exactly at the limit should pass");

        let mut below_limit = ValidationService::new(
            MockService::with_tools(&["tool"]),
            ValidationConfig {
                max_argument_size: Some(8),
            },
        );
        let resp = call_service(&mut below_limit, call_tool(args)).await;
        let err = resp.inner.unwrap_err();
        assert!(
            err.message.contains("9 bytes (limit: 8 bytes)"),
            "should report measured size and limit: {}",
            err.message
        );
    }

    #[tokio::test]
    async fn test_validation_passes_non_tool_requests() {
        let mock = MockService::with_tools(&["tool"]);
        let config = ValidationConfig {
            max_argument_size: Some(1),
        };
        let mut svc = ValidationService::new(mock, config);

        let resp = call_service(&mut svc, McpRequest::ListTools(Default::default())).await;
        assert!(resp.inner.is_ok(), "non-tool requests should pass");
    }

    #[tokio::test]
    async fn test_validation_disabled_passes_everything() {
        let mock = MockService::with_tools(&["tool"]);
        let config = ValidationConfig {
            max_argument_size: None,
        };
        let mut svc = ValidationService::new(mock, config);

        let resp =
            call_service(&mut svc, call_tool(serde_json::json!({"key": "any size is fine"}))).await;

        assert!(resp.inner.is_ok());
    }
}