1use std::{
2 future::Future,
3 pin::Pin,
4 task::{Context, Poll},
5};
6
7type PromptFuture = Pin<Box<dyn Future<Output = Result<String, PromptError>> + Send + 'static>>;
8
9use mcp_spec::{
10 content::Content,
11 handler::{PromptError, ResourceError, ToolError},
12 prompt::{Prompt, PromptMessage, PromptMessageRole},
13 protocol::{
14 CallToolResult, GetPromptResult, Implementation, InitializeResult, JsonRpcRequest,
15 JsonRpcResponse, ListPromptsResult, ListResourcesResult, ListToolsResult,
16 PromptsCapability, ReadResourceResult, ResourcesCapability, ServerCapabilities,
17 ToolsCapability,
18 },
19 ResourceContents,
20};
21use serde_json::Value;
22use tower_service::Service;
23
24use crate::{BoxError, RouterError};
25
26pub struct CapabilitiesBuilder {
28 tools: Option<ToolsCapability>,
29 prompts: Option<PromptsCapability>,
30 resources: Option<ResourcesCapability>,
31}
32
33impl Default for CapabilitiesBuilder {
34 fn default() -> Self {
35 Self::new()
36 }
37}
38
39impl CapabilitiesBuilder {
40 pub fn new() -> Self {
41 Self {
42 tools: None,
43 prompts: None,
44 resources: None,
45 }
46 }
47
48 pub fn with_tools(mut self, list_changed: bool) -> Self {
50 self.tools = Some(ToolsCapability {
51 list_changed: Some(list_changed),
52 });
53 self
54 }
55
56 pub fn with_prompts(mut self, list_changed: bool) -> Self {
58 self.prompts = Some(PromptsCapability {
59 list_changed: Some(list_changed),
60 });
61 self
62 }
63
64 pub fn with_resources(mut self, subscribe: bool, list_changed: bool) -> Self {
66 self.resources = Some(ResourcesCapability {
67 subscribe: Some(subscribe),
68 list_changed: Some(list_changed),
69 });
70 self
71 }
72
73 pub fn build(self) -> ServerCapabilities {
75 ServerCapabilities {
77 tools: self.tools,
78 prompts: self.prompts,
79 resources: self.resources,
80 }
81 }
82}
83
84pub trait Router: Send + Sync + 'static {
85 fn name(&self) -> String;
86 fn instructions(&self) -> String;
88 fn capabilities(&self) -> ServerCapabilities;
89 fn list_tools(&self) -> Vec<mcp_spec::tool::Tool>;
90 fn call_tool(
91 &self,
92 tool_name: &str,
93 arguments: Value,
94 ) -> Pin<Box<dyn Future<Output = Result<Vec<Content>, ToolError>> + Send + 'static>>;
95 fn list_resources(&self) -> Vec<mcp_spec::resource::Resource>;
96 fn read_resource(
97 &self,
98 uri: &str,
99 ) -> Pin<Box<dyn Future<Output = Result<String, ResourceError>> + Send + 'static>>;
100 fn list_prompts(&self) -> Vec<Prompt>;
101 fn get_prompt(&self, prompt_name: &str) -> PromptFuture;
102
103 fn create_response(&self, id: Option<u64>) -> JsonRpcResponse {
105 JsonRpcResponse {
106 jsonrpc: "2.0".to_string(),
107 id,
108 result: None,
109 error: None,
110 }
111 }
112
113 fn handle_initialize(
114 &self,
115 req: JsonRpcRequest,
116 ) -> impl Future<Output = Result<JsonRpcResponse, RouterError>> + Send {
117 async move {
118 let result = InitializeResult {
119 protocol_version: "2024-11-05".to_string(),
120 capabilities: self.capabilities().clone(),
121 server_info: Implementation {
122 name: self.name(),
123 version: env!("CARGO_PKG_VERSION").to_string(),
124 },
125 instructions: Some(self.instructions()),
126 };
127
128 let mut response = self.create_response(req.id);
129 response.result =
130 Some(serde_json::to_value(result).map_err(|e| {
131 RouterError::Internal(format!("JSON serialization error: {}", e))
132 })?);
133
134 Ok(response)
135 }
136 }
137
138 fn handle_tools_list(
139 &self,
140 req: JsonRpcRequest,
141 ) -> impl Future<Output = Result<JsonRpcResponse, RouterError>> + Send {
142 async move {
143 let tools = self.list_tools();
144
145 let result = ListToolsResult {
146 tools,
147 next_cursor: None,
148 };
149 let mut response = self.create_response(req.id);
150 response.result =
151 Some(serde_json::to_value(result).map_err(|e| {
152 RouterError::Internal(format!("JSON serialization error: {}", e))
153 })?);
154
155 Ok(response)
156 }
157 }
158
159 fn handle_tools_call(
160 &self,
161 req: JsonRpcRequest,
162 ) -> impl Future<Output = Result<JsonRpcResponse, RouterError>> + Send {
163 async move {
164 let params = req
165 .params
166 .ok_or_else(|| RouterError::InvalidParams("Missing parameters".into()))?;
167
168 let name = params
169 .get("name")
170 .and_then(Value::as_str)
171 .ok_or_else(|| RouterError::InvalidParams("Missing tool name".into()))?;
172
173 let arguments = params.get("arguments").cloned().unwrap_or(Value::Null);
174
175 let result = match self.call_tool(name, arguments).await {
176 Ok(result) => CallToolResult {
177 content: result,
178 is_error: None,
179 },
180 Err(err) => CallToolResult {
181 content: vec![Content::text(err.to_string())],
182 is_error: Some(true),
183 },
184 };
185
186 let mut response = self.create_response(req.id);
187 response.result =
188 Some(serde_json::to_value(result).map_err(|e| {
189 RouterError::Internal(format!("JSON serialization error: {}", e))
190 })?);
191
192 Ok(response)
193 }
194 }
195
196 fn handle_resources_list(
197 &self,
198 req: JsonRpcRequest,
199 ) -> impl Future<Output = Result<JsonRpcResponse, RouterError>> + Send {
200 async move {
201 let resources = self.list_resources();
202
203 let result = ListResourcesResult {
204 resources,
205 next_cursor: None,
206 };
207 let mut response = self.create_response(req.id);
208 response.result =
209 Some(serde_json::to_value(result).map_err(|e| {
210 RouterError::Internal(format!("JSON serialization error: {}", e))
211 })?);
212
213 Ok(response)
214 }
215 }
216
217 fn handle_resources_read(
218 &self,
219 req: JsonRpcRequest,
220 ) -> impl Future<Output = Result<JsonRpcResponse, RouterError>> + Send {
221 async move {
222 let params = req
223 .params
224 .ok_or_else(|| RouterError::InvalidParams("Missing parameters".into()))?;
225
226 let uri = params
227 .get("uri")
228 .and_then(Value::as_str)
229 .ok_or_else(|| RouterError::InvalidParams("Missing resource URI".into()))?;
230
231 let contents = self.read_resource(uri).await.map_err(RouterError::from)?;
232
233 let result = ReadResourceResult {
234 contents: vec![ResourceContents::TextResourceContents {
235 uri: uri.to_string(),
236 mime_type: Some("text/plain".to_string()),
237 text: contents,
238 }],
239 };
240
241 let mut response = self.create_response(req.id);
242 response.result =
243 Some(serde_json::to_value(result).map_err(|e| {
244 RouterError::Internal(format!("JSON serialization error: {}", e))
245 })?);
246
247 Ok(response)
248 }
249 }
250
251 fn handle_prompts_list(
252 &self,
253 req: JsonRpcRequest,
254 ) -> impl Future<Output = Result<JsonRpcResponse, RouterError>> + Send {
255 async move {
256 let prompts = self.list_prompts();
257
258 let result = ListPromptsResult { prompts };
259
260 let mut response = self.create_response(req.id);
261 response.result =
262 Some(serde_json::to_value(result).map_err(|e| {
263 RouterError::Internal(format!("JSON serialization error: {}", e))
264 })?);
265
266 Ok(response)
267 }
268 }
269
270 fn handle_prompts_get(
271 &self,
272 req: JsonRpcRequest,
273 ) -> impl Future<Output = Result<JsonRpcResponse, RouterError>> + Send {
274 async move {
275 let params = req
277 .params
278 .ok_or_else(|| RouterError::InvalidParams("Missing parameters".into()))?;
279
280 let prompt_name = params
282 .get("name")
283 .and_then(Value::as_str)
284 .ok_or_else(|| RouterError::InvalidParams("Missing prompt name".into()))?;
285
286 let arguments = params
288 .get("arguments")
289 .and_then(Value::as_object)
290 .ok_or_else(|| RouterError::InvalidParams("Missing arguments object".into()))?;
291
292 let prompt = self
294 .list_prompts()
295 .into_iter()
296 .find(|p| p.name == prompt_name)
297 .ok_or_else(|| {
298 RouterError::PromptNotFound(format!("Prompt '{}' not found", prompt_name))
299 })?;
300
301 if let Some(args) = &prompt.arguments {
303 for arg in args {
304 if arg.required.is_some()
305 && arg.required.unwrap()
306 && (!arguments.contains_key(&arg.name)
307 || arguments
308 .get(&arg.name)
309 .and_then(Value::as_str)
310 .is_none_or(str::is_empty))
311 {
312 return Err(RouterError::InvalidParams(format!(
313 "Missing required argument: '{}'",
314 arg.name
315 )));
316 }
317 }
318 }
319
320 let description = self
322 .get_prompt(prompt_name)
323 .await
324 .map_err(|e| RouterError::Internal(e.to_string()))?;
325
326 for (key, value) in arguments.iter() {
333 if key.is_empty() || key.len() > 1000 {
335 return Err(RouterError::InvalidParams(
336 "Argument keys must be between 1-1000 characters".into(),
337 ));
338 }
339
340 let value_str = value.as_str().unwrap_or_default();
341 if value_str.len() > 1000 {
342 return Err(RouterError::InvalidParams(
343 "Argument values must not exceed 1000 characters".into(),
344 ));
345 }
346
347 let dangerous_patterns = ["../", "//", "\\\\", "<script>", "{{", "}}"];
349 for pattern in dangerous_patterns {
350 if key.contains(pattern) || value_str.contains(pattern) {
351 return Err(RouterError::InvalidParams(format!(
352 "Arguments contain potentially unsafe pattern: {}",
353 pattern
354 )));
355 }
356 }
357 }
358
359 if description.len() > 10000 {
361 return Err(RouterError::Internal(
362 "Prompt description exceeds maximum allowed length".into(),
363 ));
364 }
365
366 let mut description_filled = description.clone();
368
369 for (key, value) in arguments {
371 let placeholder = format!("{{{}}}", key);
372 description_filled =
373 description_filled.replace(&placeholder, value.as_str().unwrap_or_default());
374 }
375
376 let messages = vec![PromptMessage::new_text(
377 PromptMessageRole::User,
378 description_filled.to_string(),
379 )];
380
381 let mut response = self.create_response(req.id);
383 response.result = Some(
384 serde_json::to_value(GetPromptResult {
385 description: Some(description_filled),
386 messages,
387 })
388 .map_err(|e| RouterError::Internal(format!("JSON serialization error: {}", e)))?,
389 );
390 Ok(response)
391 }
392 }
393}
394
/// Newtype that exposes a [`Router`] as a `tower_service::Service` over
/// JSON-RPC requests; the wrapped router is the public field `.0`.
pub struct RouterService<T>(pub T);
396
397impl<T> Service<JsonRpcRequest> for RouterService<T>
398where
399 T: Router + Clone + Send + Sync + 'static,
400{
401 type Response = JsonRpcResponse;
402 type Error = BoxError;
403 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
404
405 fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
406 Poll::Ready(Ok(()))
407 }
408
409 fn call(&mut self, req: JsonRpcRequest) -> Self::Future {
410 let this = self.0.clone();
411
412 Box::pin(async move {
413 let result = match req.method.as_str() {
414 "initialize" => this.handle_initialize(req).await,
415 "tools/list" => this.handle_tools_list(req).await,
416 "tools/call" => this.handle_tools_call(req).await,
417 "resources/list" => this.handle_resources_list(req).await,
418 "resources/read" => this.handle_resources_read(req).await,
419 "prompts/list" => this.handle_prompts_list(req).await,
420 "prompts/get" => this.handle_prompts_get(req).await,
421 _ => {
422 let mut response = this.create_response(req.id);
423 response.error = Some(RouterError::MethodNotFound(req.method).into());
424 Ok(response)
425 }
426 };
427
428 result.map_err(BoxError::from)
429 })
430 }
431}