1use std::{collections::HashSet, marker::PhantomData, pin::pin, sync::Arc};
4
5use futures::{
6 SinkExt,
7 channel::{mpsc, oneshot},
8 future::{BoxFuture, Either},
9};
10use futures_concurrency::future::TryJoin;
11use rustc_hash::FxHashMap;
12
/// Policy deciding which registered tools a server exposes.
///
/// Equality compares the variant and the exact set of names, which lets callers check
/// whether two policies are the same.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum EnabledTools {
    /// Every tool is enabled except the ones named in the set.
    DenyList(HashSet<String>),
    /// Only the tools named in the set are enabled.
    AllowList(HashSet<String>),
}
24
25impl Default for EnabledTools {
26 fn default() -> Self {
27 EnabledTools::DenyList(HashSet::new())
28 }
29}
30
31impl EnabledTools {
32 #[must_use]
34 pub fn is_enabled(&self, name: &str) -> bool {
35 match self {
36 EnabledTools::DenyList(deny) => !deny.contains(name),
37 EnabledTools::AllowList(allow) => allow.contains(name),
38 }
39 }
40}
41use rmcp::{
42 ErrorData, ServerHandler,
43 handler::server::tool::{schema_for_output, schema_for_type},
44 model::{CallToolResult, ListToolsResult, Tool},
45};
46use schemars::JsonSchema;
47use serde::{Serialize, de::DeserializeOwned};
48use tokio_util::compat::{TokioAsyncReadCompatExt, TokioAsyncWriteCompatExt};
49
50use super::{McpConnectionTo, McpTool};
51use crate::{
52 ByteStreams, ConnectTo, DynConnectTo,
53 jsonrpc::run::{ChainRun, NullRun, RunWithConnectionTo},
54 mcp_server::{
55 McpServer, McpServerConnect,
56 responder::{ToolCall, ToolFnMutResponder, ToolFnResponder},
57 },
58 role::{self, Role},
59};
60
61#[derive(Debug)]
81pub struct McpServerBuilder<Counterpart: Role, Responder>
82where
83 Responder: RunWithConnectionTo<Counterpart>,
84{
85 phantom: PhantomData<Counterpart>,
86 name: String,
87 data: McpServerData<Counterpart>,
88 responder: Responder,
89}
90
91#[derive(Debug)]
92struct McpServerData<Counterpart: Role> {
93 instructions: Option<String>,
94 tool_models: Vec<rmcp::model::Tool>,
95 tools: FxHashMap<String, RegisteredTool<Counterpart>>,
96 enabled_tools: EnabledTools,
97}
98
99struct RegisteredTool<Counterpart: Role> {
101 tool: Arc<dyn ErasedMcpTool<Counterpart>>,
102 has_structured_output: bool,
104}
105
106impl<Counterpart: Role + std::fmt::Debug> std::fmt::Debug for RegisteredTool<Counterpart> {
107 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
108 f.debug_struct("RegisteredTool")
109 .field("has_structured_output", &self.has_structured_output)
110 .finish_non_exhaustive()
111 }
112}
113
114impl<Host: Role> Default for McpServerData<Host> {
115 fn default() -> Self {
116 Self {
117 instructions: None,
118 tool_models: Vec::new(),
119 tools: FxHashMap::default(),
120 enabled_tools: EnabledTools::default(),
121 }
122 }
123}
124
125impl<Counterpart: Role> McpServerBuilder<Counterpart, NullRun> {
126 pub(super) fn new(name: String) -> Self {
127 Self {
128 name,
129 phantom: PhantomData,
130 data: McpServerData::default(),
131 responder: NullRun,
132 }
133 }
134}
135
136impl<Counterpart: Role, Responder> McpServerBuilder<Counterpart, Responder>
137where
138 Responder: RunWithConnectionTo<Counterpart>,
139{
140 #[must_use]
142 pub fn instructions(mut self, instructions: impl ToString) -> Self {
143 self.data.instructions = Some(instructions.to_string());
144 self
145 }
146
147 #[must_use]
149 pub fn tool(mut self, tool: impl McpTool<Counterpart> + 'static) -> Self {
150 let tool_model = make_tool_model(&tool);
151 let has_structured_output = tool_model.output_schema.is_some();
152 self.data.tool_models.push(tool_model);
153 self.data.tools.insert(
154 tool.name(),
155 RegisteredTool {
156 tool: make_erased_mcp_tool(tool),
157 has_structured_output,
158 },
159 );
160 self
161 }
162
163 #[must_use]
166 pub fn disable_all_tools(mut self) -> Self {
167 self.data.enabled_tools = EnabledTools::AllowList(HashSet::new());
168 self
169 }
170
171 #[must_use]
174 pub fn enable_all_tools(mut self) -> Self {
175 self.data.enabled_tools = EnabledTools::DenyList(HashSet::new());
176 self
177 }
178
179 pub fn disable_tool(mut self, name: &str) -> Result<Self, crate::Error> {
183 if !self.data.tools.contains_key(name) {
184 return Err(crate::Error::invalid_request().data(format!("unknown tool: {name}")));
185 }
186 match &mut self.data.enabled_tools {
187 EnabledTools::DenyList(deny) => {
188 deny.insert(name.to_string());
189 }
190 EnabledTools::AllowList(allow) => {
191 allow.remove(name);
192 }
193 }
194 Ok(self)
195 }
196
197 pub fn enable_tool(mut self, name: &str) -> Result<Self, crate::Error> {
201 if !self.data.tools.contains_key(name) {
202 return Err(crate::Error::invalid_request().data(format!("unknown tool: {name}")));
203 }
204 match &mut self.data.enabled_tools {
205 EnabledTools::DenyList(deny) => {
206 deny.remove(name);
207 }
208 EnabledTools::AllowList(allow) => {
209 allow.insert(name.to_string());
210 }
211 }
212 Ok(self)
213 }
214
215 fn tool_with_responder(
218 self,
219 tool: impl McpTool<Counterpart> + 'static,
220 tool_responder: impl RunWithConnectionTo<Counterpart>,
221 ) -> McpServerBuilder<Counterpart, impl RunWithConnectionTo<Counterpart>> {
222 let this = self.tool(tool);
223 McpServerBuilder {
224 phantom: PhantomData,
225 name: this.name,
226 data: this.data,
227 responder: ChainRun::new(this.responder, tool_responder),
228 }
229 }
230
231 pub fn tool_fn_mut<P, Ret, F>(
254 self,
255 name: impl ToString,
256 description: impl ToString,
257 func: F,
258 tool_future_hack: impl for<'a> Fn(
259 &'a mut F,
260 P,
261 McpConnectionTo<Counterpart>,
262 ) -> BoxFuture<'a, Result<Ret, crate::Error>>
263 + Send
264 + 'static,
265 ) -> McpServerBuilder<Counterpart, impl RunWithConnectionTo<Counterpart>>
266 where
267 P: JsonSchema + DeserializeOwned + 'static + Send,
268 Ret: JsonSchema + Serialize + 'static + Send,
269 F: AsyncFnMut(P, McpConnectionTo<Counterpart>) -> Result<Ret, crate::Error> + Send,
270 {
271 let (call_tx, call_rx) = mpsc::channel(128);
272 self.tool_with_responder(
273 ToolFnTool {
274 name: name.to_string(),
275 description: description.to_string(),
276 call_tx,
277 },
278 ToolFnMutResponder {
279 func,
280 call_rx,
281 tool_future_fn: Box::new(tool_future_hack),
282 },
283 )
284 }
285
286 pub fn tool_fn<P, Ret, F>(
307 self,
308 name: impl ToString,
309 description: impl ToString,
310 func: F,
311 tool_future_hack: impl for<'a> Fn(
312 &'a F,
313 P,
314 McpConnectionTo<Counterpart>,
315 ) -> BoxFuture<'a, Result<Ret, crate::Error>>
316 + Send
317 + Sync
318 + 'static,
319 ) -> McpServerBuilder<Counterpart, impl RunWithConnectionTo<Counterpart>>
320 where
321 P: JsonSchema + DeserializeOwned + 'static + Send,
322 Ret: JsonSchema + Serialize + 'static + Send,
323 F: AsyncFn(P, McpConnectionTo<Counterpart>) -> Result<Ret, crate::Error>
324 + Send
325 + Sync
326 + 'static,
327 {
328 let (call_tx, call_rx) = mpsc::channel(128);
329 self.tool_with_responder(
330 ToolFnTool {
331 name: name.to_string(),
332 description: description.to_string(),
333 call_tx,
334 },
335 ToolFnResponder {
336 func,
337 call_rx,
338 tool_future_fn: Box::new(tool_future_hack),
339 },
340 )
341 }
342
343 pub fn build(self) -> McpServer<Counterpart, Responder> {
348 McpServer::new(
349 McpServerBuilt {
350 name: self.name,
351 data: Arc::new(self.data),
352 },
353 self.responder,
354 )
355 }
356}
357
358struct McpServerBuilt<Counterpart: Role> {
359 name: String,
360 data: Arc<McpServerData<Counterpart>>,
361}
362
363impl<Counterpart: Role> McpServerConnect<Counterpart> for McpServerBuilt<Counterpart> {
364 fn name(&self) -> String {
365 self.name.clone()
366 }
367
368 fn connect(
369 &self,
370 mcp_connection: McpConnectionTo<Counterpart>,
371 ) -> DynConnectTo<role::mcp::Client> {
372 DynConnectTo::new(McpServerConnection {
373 data: self.data.clone(),
374 mcp_connection,
375 })
376 }
377}
378
379pub(crate) struct McpServerConnection<Counterpart: Role> {
381 data: Arc<McpServerData<Counterpart>>,
382 mcp_connection: McpConnectionTo<Counterpart>,
383}
384
385impl<Counterpart: Role> ConnectTo<role::mcp::Client> for McpServerConnection<Counterpart> {
386 async fn connect_to(
387 self,
388 client: impl ConnectTo<role::mcp::Server>,
389 ) -> Result<(), crate::Error> {
390 let (mcp_server_stream, mcp_client_stream) = tokio::io::duplex(8192);
392 let (mcp_server_read, mcp_server_write) = tokio::io::split(mcp_server_stream);
393 let (mcp_client_read, mcp_client_write) = tokio::io::split(mcp_client_stream);
394
395 let run_client = async {
396 let byte_streams =
398 ByteStreams::new(mcp_client_write.compat_write(), mcp_client_read.compat());
399 drop(
400 <ByteStreams<_, _> as ConnectTo<role::mcp::Client>>::connect_to(
401 byte_streams,
402 client,
403 )
404 .await,
405 );
406 Ok(())
407 };
408
409 let run_server = async {
410 let running_server = rmcp::ServiceExt::serve(self, (mcp_server_read, mcp_server_write))
412 .await
413 .map_err(crate::Error::into_internal_error)?;
414
415 running_server
417 .waiting()
418 .await
419 .map(|_quit_reason| ())
420 .map_err(crate::Error::into_internal_error)
421 };
422
423 (run_client, run_server).try_join().await?;
424 Ok(())
425 }
426}
427
428impl<R: Role> ServerHandler for McpServerConnection<R> {
429 async fn call_tool(
430 &self,
431 request: rmcp::model::CallToolRequestParams,
432 context: rmcp::service::RequestContext<rmcp::RoleServer>,
433 ) -> Result<CallToolResult, ErrorData> {
434 let Some(registered) = self.data.tools.get(&request.name[..]) else {
436 return Err(rmcp::model::ErrorData::invalid_params(
437 format!("tool `{}` not found", request.name),
438 None,
439 ));
440 };
441
442 if !self.data.enabled_tools.is_enabled(&request.name) {
444 return Err(rmcp::model::ErrorData::invalid_params(
445 format!("tool `{}` not found", request.name),
446 None,
447 ));
448 }
449
450 let serde_value = serde_json::to_value(request.arguments).expect("valid json");
452
453 let has_structured_output = registered.has_structured_output;
455 match futures::future::select(
456 registered
457 .tool
458 .call_tool(serde_value, self.mcp_connection.clone()),
459 pin!(context.ct.cancelled()),
460 )
461 .await
462 {
463 Either::Left((m, _)) => match m {
465 Ok(result) => {
466 if has_structured_output {
468 Ok(CallToolResult::structured(result))
469 } else {
470 Ok(CallToolResult::success(vec![rmcp::model::Content::text(
471 result.to_string(),
472 )]))
473 }
474 }
475 Err(error) => Err(to_rmcp_error(error)),
476 },
477
478 Either::Right(((), _)) => {
480 Err(rmcp::ErrorData::internal_error("operation cancelled", None))
481 }
482 }
483 }
484
485 async fn list_tools(
486 &self,
487 _request: Option<rmcp::model::PaginatedRequestParams>,
488 _context: rmcp::service::RequestContext<rmcp::RoleServer>,
489 ) -> Result<rmcp::model::ListToolsResult, ErrorData> {
490 let tools: Vec<_> = self
492 .data
493 .tool_models
494 .iter()
495 .filter(|t| self.data.enabled_tools.is_enabled(&t.name))
496 .cloned()
497 .collect();
498 Ok(ListToolsResult::with_all_items(tools))
499 }
500
501 fn get_info(&self) -> rmcp::model::ServerInfo {
502 let base = rmcp::model::ServerInfo::new(
504 rmcp::model::ServerCapabilities::builder()
505 .enable_tools()
506 .build(),
507 )
508 .with_server_info(rmcp::model::Implementation::default())
509 .with_protocol_version(rmcp::model::ProtocolVersion::default());
510
511 if let Some(instr) = self.data.instructions.clone() {
512 base.with_instructions(instr)
513 } else {
514 base
515 }
516 }
517}
518
519trait ErasedMcpTool<Counterpart: Role>: Send + Sync {
521 fn call_tool(
522 &self,
523 input: serde_json::Value,
524 connection: McpConnectionTo<Counterpart>,
525 ) -> BoxFuture<'_, Result<serde_json::Value, crate::Error>>;
526}
527
528fn make_tool_model<R: Role, M: McpTool<R>>(tool: &M) -> Tool {
530 let mut tool = rmcp::model::Tool::new(
531 tool.name(),
532 tool.description(),
533 schema_for_type::<M::Input>(),
534 )
535 .with_execution(rmcp::model::ToolExecution::new());
536
537 if let Ok(schema) = schema_for_output::<M::Output>() {
538 tool = tool.with_raw_output_schema(schema);
542 }
543
544 tool
545}
546
547fn make_erased_mcp_tool<'s, R: Role, M: McpTool<R> + 's>(
549 tool: M,
550) -> Arc<dyn ErasedMcpTool<R> + 's> {
551 struct ErasedMcpToolImpl<M> {
552 tool: M,
553 }
554
555 impl<R, M> ErasedMcpTool<R> for ErasedMcpToolImpl<M>
556 where
557 R: Role,
558 M: McpTool<R>,
559 {
560 fn call_tool(
561 &self,
562 input: serde_json::Value,
563 context: McpConnectionTo<R>,
564 ) -> BoxFuture<'_, Result<serde_json::Value, crate::Error>> {
565 Box::pin(async move {
566 let input = serde_json::from_value(input).map_err(crate::util::internal_error)?;
567 serde_json::to_value(self.tool.call_tool(input, context).await?)
568 .map_err(crate::util::internal_error)
569 })
570 }
571 }
572
573 Arc::new(ErasedMcpToolImpl { tool })
574}
575
576fn to_rmcp_error(error: crate::Error) -> rmcp::ErrorData {
578 rmcp::ErrorData {
579 code: rmcp::model::ErrorCode(error.code.into()),
580 message: error.message.into(),
581 data: error.data,
582 }
583}
584
585struct ToolFnTool<P, Ret, R: Role> {
588 name: String,
589 description: String,
590 call_tx: mpsc::Sender<ToolCall<P, Ret, R>>,
591}
592
593impl<P, Ret, R> McpTool<R> for ToolFnTool<P, Ret, R>
594where
595 R: Role,
596 P: JsonSchema + DeserializeOwned + 'static + Send,
597 Ret: JsonSchema + Serialize + 'static + Send,
598{
599 type Input = P;
600 type Output = Ret;
601
602 fn name(&self) -> String {
603 self.name.clone()
604 }
605
606 fn description(&self) -> String {
607 self.description.clone()
608 }
609
610 async fn call_tool(
611 &self,
612 params: P,
613 mcp_connection: McpConnectionTo<R>,
614 ) -> Result<Ret, crate::Error> {
615 let (result_tx, result_rx) = oneshot::channel();
616
617 self.call_tx
618 .clone()
619 .send(ToolCall {
620 params,
621 mcp_connection,
622 result_tx,
623 })
624 .await
625 .map_err(crate::util::internal_error)?;
626
627 result_rx.await.map_err(crate::util::internal_error)?
628 }
629}