1use std::convert::Infallible;
6use std::future::Future;
7use std::pin::Pin;
8use std::sync::Arc;
9use std::task::{Context, Poll};
10
11use tower::Service;
12
13use tower_mcp::protocol::McpRequest;
14use tower_mcp::{RouterRequest, RouterResponse};
15use tower_mcp_types::JsonRpcError;
16
/// Limits enforced by [`ValidationService`] on incoming requests.
///
/// The [`Default`] configuration sets every limit to `None`, i.e. performs no
/// validation at all.
#[derive(Clone, Debug, Default)]
pub struct ValidationConfig {
    /// Maximum size, in bytes, of a `CallTool` request's arguments once encoded
    /// as compact JSON. Larger requests are rejected with an invalid-params
    /// error before reaching the inner service; `None` disables the check.
    pub max_argument_size: Option<usize>,
}
23
24#[derive(Clone)]
26pub struct ValidationService<S> {
27 inner: S,
28 config: Arc<ValidationConfig>,
29}
30
31impl<S> ValidationService<S> {
32 pub fn new(inner: S, config: ValidationConfig) -> Self {
34 Self {
35 inner,
36 config: Arc::new(config),
37 }
38 }
39}
40
41impl<S> Service<RouterRequest> for ValidationService<S>
42where
43 S: Service<RouterRequest, Response = RouterResponse, Error = Infallible>
44 + Clone
45 + Send
46 + 'static,
47 S::Future: Send,
48{
49 type Response = RouterResponse;
50 type Error = Infallible;
51 type Future = Pin<Box<dyn Future<Output = Result<RouterResponse, Infallible>> + Send>>;
52
53 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
54 self.inner.poll_ready(cx)
55 }
56
57 fn call(&mut self, req: RouterRequest) -> Self::Future {
58 let config = Arc::clone(&self.config);
59 let request_id = req.id.clone();
60
61 if let McpRequest::CallTool(ref params) = req.inner
63 && let Some(max_size) = config.max_argument_size
64 {
65 let size = serde_json::to_string(¶ms.arguments)
66 .map(|s| s.len())
67 .unwrap_or(0);
68 if size > max_size {
69 return Box::pin(async move {
70 Ok(RouterResponse {
71 id: request_id,
72 inner: Err(JsonRpcError::invalid_params(format!(
73 "Tool arguments exceed maximum size: {} bytes (limit: {} bytes)",
74 size, max_size
75 ))),
76 })
77 });
78 }
79 }
80
81 let fut = self.inner.call(req);
82 Box::pin(fut)
83 }
84}
85
#[cfg(test)]
mod tests {
    use tower_mcp::protocol::{CallToolParams, McpRequest};

    use super::{ValidationConfig, ValidationService};
    use crate::test_util::{MockService, call_service};

    /// Builds a `CallTool` request for the mock's `tool` with the given arguments.
    fn call_tool(arguments: serde_json::Value) -> McpRequest {
        McpRequest::CallTool(CallToolParams {
            name: "tool".to_string(),
            arguments,
            meta: None,
            task: None,
        })
    }

    #[tokio::test]
    async fn test_validation_passes_small_arguments() {
        let mock = MockService::with_tools(&["tool"]);
        let config = ValidationConfig {
            max_argument_size: Some(1024),
        };
        let mut svc = ValidationService::new(mock, config);

        let resp = call_service(&mut svc, call_tool(serde_json::json!({"key": "small"}))).await;

        assert!(resp.inner.is_ok(), "small args should pass validation");
    }

    #[tokio::test]
    async fn test_validation_rejects_large_arguments() {
        let mock = MockService::with_tools(&["tool"]);
        let config = ValidationConfig {
            max_argument_size: Some(10),
        };
        let mut svc = ValidationService::new(mock, config);

        let resp = call_service(
            &mut svc,
            call_tool(serde_json::json!({"key": "this string is definitely longer than 10 bytes"})),
        )
        .await;

        let err = resp.inner.unwrap_err();
        assert!(
            err.message.contains("exceed maximum size"),
            "should mention size exceeded: {}",
            err.message
        );
        assert!(
            err.message.contains("(limit: 10 bytes)"),
            "should report the configured limit: {}",
            err.message
        );
    }

    #[tokio::test]
    async fn test_validation_limit_is_inclusive() {
        // `{"k":"v"}` encodes to exactly 9 bytes of compact JSON.
        let args = serde_json::json!({"k": "v"});

        let mut at_limit = ValidationService::new(
            MockService::with_tools(&["tool"]),
            ValidationConfig {
                max_argument_size: Some(9),
            },
        );
        let resp = call_service(&mut at_limit, call_tool(args.clone())).await;
        assert!(resp.inner.is_ok(), "arguments exactly at the limit should pass");

        let mut below_limit = ValidationService::new(
            MockService::with_tools(&["tool"]),
            ValidationConfig {
                max_argument_size: Some(8),
            },
        );
        let resp = call_service(&mut below_limit, call_tool(args)).await;
        let err = resp.inner.unwrap_err();
        assert!(
            err.message.contains("9 bytes (limit: 8 bytes)"),
            "should report exact size and limit: {}",
            err.message
        );
    }

    #[tokio::test]
    async fn test_validation_passes_non_tool_requests() {
        let mock = MockService::with_tools(&["tool"]);
        let config = ValidationConfig {
            max_argument_size: Some(1),
        };
        let mut svc = ValidationService::new(mock, config);

        let resp = call_service(&mut svc, McpRequest::ListTools(Default::default())).await;
        assert!(resp.inner.is_ok(), "non-tool requests should pass");
    }

    #[tokio::test]
    async fn test_validation_disabled_passes_everything() {
        let mock = MockService::with_tools(&["tool"]);
        let config = ValidationConfig {
            max_argument_size: None,
        };
        let mut svc = ValidationService::new(mock, config);

        let resp =
            call_service(&mut svc, call_tool(serde_json::json!({"key": "any size is fine"}))).await;

        assert!(resp.inner.is_ok());
    }
}