Skip to main content

agent_client_protocol/mcp_server/
builder.rs

1//! MCP server builder for creating MCP servers.
2
3use std::{collections::HashSet, marker::PhantomData, pin::pin, sync::Arc};
4
5use futures::{
6    SinkExt,
7    channel::{mpsc, oneshot},
8    future::{BoxFuture, Either},
9};
10use futures_concurrency::future::TryJoin;
11use rustc_hash::FxHashMap;
12
/// Policy deciding which registered tools are exposed to clients.
///
/// - `DenyList`: every tool is enabled except the names in the set (this is the default)
/// - `AllowList`: only the names in the set are enabled
#[derive(Clone, Debug)]
pub enum EnabledTools {
    /// Every tool is enabled except those named in this set.
    DenyList(HashSet<String>),
    /// Only the tools named in this set are enabled.
    AllowList(HashSet<String>),
}

impl Default for EnabledTools {
    /// An empty deny list, i.e. every tool is enabled.
    fn default() -> Self {
        Self::DenyList(HashSet::default())
    }
}

impl EnabledTools {
    /// Returns `true` when the tool called `name` is enabled under this policy.
    #[must_use]
    pub fn is_enabled(&self, name: &str) -> bool {
        // A deny list enables the names it does *not* contain; an allow list
        // enables exactly the names it *does* contain.
        let (names, listed_means_enabled) = match self {
            Self::DenyList(names) => (names, false),
            Self::AllowList(names) => (names, true),
        };
        names.contains(name) == listed_means_enabled
    }
}
41use rmcp::{
42    ErrorData, ServerHandler,
43    handler::server::tool::{schema_for_output, schema_for_type},
44    model::{CallToolResult, ListToolsResult, Tool},
45};
46use schemars::JsonSchema;
47use serde::{Serialize, de::DeserializeOwned};
48use tokio_util::compat::{TokioAsyncReadCompatExt, TokioAsyncWriteCompatExt};
49
50use super::{McpConnectionTo, McpTool};
51use crate::{
52    ByteStreams, ConnectTo, DynConnectTo,
53    jsonrpc::run::{ChainRun, NullRun, RunWithConnectionTo},
54    mcp_server::{
55        McpServer, McpServerConnect,
56        responder::{ToolCall, ToolFnMutResponder, ToolFnResponder},
57    },
58    role::{self, Role},
59};
60
61/// Builder for creating MCP servers with tools.
62///
63/// Use [`McpServer::builder`] to create a new builder, then chain methods to
64/// configure the server and call [`build`](Self::build) to create the server.
65///
66/// # Example
67///
68/// ```rust,ignore
69/// let server = McpServer::builder("my-server".to_string())
70///     .instructions("A helpful assistant")
71///     .tool(EchoTool)
72///     .tool_fn(
73///         "greet",
74///         "Greet someone by name",
75///         async |input: GreetInput, _cx| Ok(format!("Hello, {}!", input.name)),
76///         agent_client_protocol::tool_fn!(),
77///     )
78///     .build();
79/// ```
80#[derive(Debug)]
81pub struct McpServerBuilder<Counterpart: Role, Responder>
82where
83    Responder: RunWithConnectionTo<Counterpart>,
84{
85    phantom: PhantomData<Counterpart>,
86    name: String,
87    data: McpServerData<Counterpart>,
88    responder: Responder,
89}
90
91#[derive(Debug)]
92struct McpServerData<Counterpart: Role> {
93    instructions: Option<String>,
94    tool_models: Vec<rmcp::model::Tool>,
95    tools: FxHashMap<String, RegisteredTool<Counterpart>>,
96    enabled_tools: EnabledTools,
97}
98
99/// A registered tool with its metadata.
100struct RegisteredTool<Counterpart: Role> {
101    tool: Arc<dyn ErasedMcpTool<Counterpart>>,
102    /// Whether this tool returns structured output (i.e., has an output_schema).
103    has_structured_output: bool,
104}
105
106impl<Counterpart: Role + std::fmt::Debug> std::fmt::Debug for RegisteredTool<Counterpart> {
107    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
108        f.debug_struct("RegisteredTool")
109            .field("has_structured_output", &self.has_structured_output)
110            .finish_non_exhaustive()
111    }
112}
113
114impl<Host: Role> Default for McpServerData<Host> {
115    fn default() -> Self {
116        Self {
117            instructions: None,
118            tool_models: Vec::new(),
119            tools: FxHashMap::default(),
120            enabled_tools: EnabledTools::default(),
121        }
122    }
123}
124
125impl<Counterpart: Role> McpServerBuilder<Counterpart, NullRun> {
126    pub(super) fn new(name: String) -> Self {
127        Self {
128            name,
129            phantom: PhantomData,
130            data: McpServerData::default(),
131            responder: NullRun,
132        }
133    }
134}
135
136impl<Counterpart: Role, Responder> McpServerBuilder<Counterpart, Responder>
137where
138    Responder: RunWithConnectionTo<Counterpart>,
139{
140    /// Set the server instructions that are provided to the client.
141    #[must_use]
142    pub fn instructions(mut self, instructions: impl ToString) -> Self {
143        self.data.instructions = Some(instructions.to_string());
144        self
145    }
146
147    /// Add a tool to the server.
148    #[must_use]
149    pub fn tool(mut self, tool: impl McpTool<Counterpart> + 'static) -> Self {
150        let tool_model = make_tool_model(&tool);
151        let has_structured_output = tool_model.output_schema.is_some();
152        self.data.tool_models.push(tool_model);
153        self.data.tools.insert(
154            tool.name(),
155            RegisteredTool {
156                tool: make_erased_mcp_tool(tool),
157                has_structured_output,
158            },
159        );
160        self
161    }
162
163    /// Disable all tools. After calling this, only tools explicitly enabled
164    /// with [`enable_tool`](Self::enable_tool) will be available.
165    #[must_use]
166    pub fn disable_all_tools(mut self) -> Self {
167        self.data.enabled_tools = EnabledTools::AllowList(HashSet::new());
168        self
169    }
170
171    /// Enable all tools. After calling this, all tools will be available
172    /// except those explicitly disabled with [`disable_tool`](Self::disable_tool).
173    #[must_use]
174    pub fn enable_all_tools(mut self) -> Self {
175        self.data.enabled_tools = EnabledTools::DenyList(HashSet::new());
176        self
177    }
178
179    /// Disable a specific tool by name.
180    ///
181    /// Returns an error if the tool is not registered.
182    pub fn disable_tool(mut self, name: &str) -> Result<Self, crate::Error> {
183        if !self.data.tools.contains_key(name) {
184            return Err(crate::Error::invalid_request().data(format!("unknown tool: {name}")));
185        }
186        match &mut self.data.enabled_tools {
187            EnabledTools::DenyList(deny) => {
188                deny.insert(name.to_string());
189            }
190            EnabledTools::AllowList(allow) => {
191                allow.remove(name);
192            }
193        }
194        Ok(self)
195    }
196
197    /// Enable a specific tool by name.
198    ///
199    /// Returns an error if the tool is not registered.
200    pub fn enable_tool(mut self, name: &str) -> Result<Self, crate::Error> {
201        if !self.data.tools.contains_key(name) {
202            return Err(crate::Error::invalid_request().data(format!("unknown tool: {name}")));
203        }
204        match &mut self.data.enabled_tools {
205            EnabledTools::DenyList(deny) => {
206                deny.remove(name);
207            }
208            EnabledTools::AllowList(allow) => {
209                allow.insert(name.to_string());
210            }
211        }
212        Ok(self)
213    }
214
215    /// Private fn: adds the tool but also adds a responder that will be
216    /// run while the MCP server is active.
217    fn tool_with_responder(
218        self,
219        tool: impl McpTool<Counterpart> + 'static,
220        tool_responder: impl RunWithConnectionTo<Counterpart>,
221    ) -> McpServerBuilder<Counterpart, impl RunWithConnectionTo<Counterpart>> {
222        let this = self.tool(tool);
223        McpServerBuilder {
224            phantom: PhantomData,
225            name: this.name,
226            data: this.data,
227            responder: ChainRun::new(this.responder, tool_responder),
228        }
229    }
230
231    /// Convenience wrapper for defining a "single-threaded" tool without having to create a struct.
232    /// By "single-threaded", we mean that only one invocation of the tool can be running at a time.
233    /// Typically agents invoke a tool once per session and then block waiting for the result,
234    /// so this is fine, but they could attempt to run multiple invocations concurrently, in which
235    /// case those invocations would be serialized.
236    ///
237    /// # Parameters
238    ///
239    /// * `name`: The name of the tool.
240    /// * `description`: The description of the tool.
241    /// * `func`: The function that implements the tool. Use an async closure like `async |args, cx| { .. }`.
242    ///
243    /// # Examples
244    ///
245    /// ```rust,ignore
246    /// McpServer::builder("my-server")
247    ///     .tool_fn_mut(
248    ///         "greet",
249    ///         "Greet someone by name",
250    ///         async |input: GreetInput, _cx| Ok(format!("Hello, {}!", input.name)),
251    ///     )
252    /// ```
253    pub fn tool_fn_mut<P, Ret, F>(
254        self,
255        name: impl ToString,
256        description: impl ToString,
257        func: F,
258        tool_future_hack: impl for<'a> Fn(
259            &'a mut F,
260            P,
261            McpConnectionTo<Counterpart>,
262        ) -> BoxFuture<'a, Result<Ret, crate::Error>>
263        + Send
264        + 'static,
265    ) -> McpServerBuilder<Counterpart, impl RunWithConnectionTo<Counterpart>>
266    where
267        P: JsonSchema + DeserializeOwned + 'static + Send,
268        Ret: JsonSchema + Serialize + 'static + Send,
269        F: AsyncFnMut(P, McpConnectionTo<Counterpart>) -> Result<Ret, crate::Error> + Send,
270    {
271        let (call_tx, call_rx) = mpsc::channel(128);
272        self.tool_with_responder(
273            ToolFnTool {
274                name: name.to_string(),
275                description: description.to_string(),
276                call_tx,
277            },
278            ToolFnMutResponder {
279                func,
280                call_rx,
281                tool_future_fn: Box::new(tool_future_hack),
282            },
283        )
284    }
285
286    /// Convenience wrapper for defining a stateless tool that can run concurrently.
287    /// Unlike [`tool_fn_mut`](Self::tool_fn_mut), multiple invocations of this tool can run
288    /// at the same time since the function is `Fn` rather than `FnMut`.
289    ///
290    /// # Parameters
291    ///
292    /// * `name`: The name of the tool.
293    /// * `description`: The description of the tool.
294    /// * `func`: The function that implements the tool. Use an async closure like `async |args, cx| { .. }`.
295    ///
296    /// # Examples
297    ///
298    /// ```rust,ignore
299    /// McpServer::builder("my-server")
300    ///     .tool_fn(
301    ///         "greet",
302    ///         "Greet someone by name",
303    ///         async |input: GreetInput, _cx| Ok(format!("Hello, {}!", input.name)),
304    ///     )
305    /// ```
306    pub fn tool_fn<P, Ret, F>(
307        self,
308        name: impl ToString,
309        description: impl ToString,
310        func: F,
311        tool_future_hack: impl for<'a> Fn(
312            &'a F,
313            P,
314            McpConnectionTo<Counterpart>,
315        ) -> BoxFuture<'a, Result<Ret, crate::Error>>
316        + Send
317        + Sync
318        + 'static,
319    ) -> McpServerBuilder<Counterpart, impl RunWithConnectionTo<Counterpart>>
320    where
321        P: JsonSchema + DeserializeOwned + 'static + Send,
322        Ret: JsonSchema + Serialize + 'static + Send,
323        F: AsyncFn(P, McpConnectionTo<Counterpart>) -> Result<Ret, crate::Error>
324            + Send
325            + Sync
326            + 'static,
327    {
328        let (call_tx, call_rx) = mpsc::channel(128);
329        self.tool_with_responder(
330            ToolFnTool {
331                name: name.to_string(),
332                description: description.to_string(),
333                call_tx,
334            },
335            ToolFnResponder {
336                func,
337                call_rx,
338                tool_future_fn: Box::new(tool_future_hack),
339            },
340        )
341    }
342
343    /// Create an MCP server from this builder.
344    ///
345    /// This builder can be attached to new sessions (see [`SessionBuilder::with_mcp_server`](`crate::SessionBuilder::with_mcp_server`))
346    /// or served up as part of a proxy (see [`Builder::with_mcp_server`](`crate::Builder::with_mcp_server`)).
347    pub fn build(self) -> McpServer<Counterpart, Responder> {
348        McpServer::new(
349            McpServerBuilt {
350                name: self.name,
351                data: Arc::new(self.data),
352            },
353            self.responder,
354        )
355    }
356}
357
358struct McpServerBuilt<Counterpart: Role> {
359    name: String,
360    data: Arc<McpServerData<Counterpart>>,
361}
362
363impl<Counterpart: Role> McpServerConnect<Counterpart> for McpServerBuilt<Counterpart> {
364    fn name(&self) -> String {
365        self.name.clone()
366    }
367
368    fn connect(
369        &self,
370        mcp_connection: McpConnectionTo<Counterpart>,
371    ) -> DynConnectTo<role::mcp::Client> {
372        DynConnectTo::new(McpServerConnection {
373            data: self.data.clone(),
374            mcp_connection,
375        })
376    }
377}
378
379/// An MCP server instance connected to the ACP framework.
380pub(crate) struct McpServerConnection<Counterpart: Role> {
381    data: Arc<McpServerData<Counterpart>>,
382    mcp_connection: McpConnectionTo<Counterpart>,
383}
384
385impl<Counterpart: Role> ConnectTo<role::mcp::Client> for McpServerConnection<Counterpart> {
386    async fn connect_to(
387        self,
388        client: impl ConnectTo<role::mcp::Server>,
389    ) -> Result<(), crate::Error> {
390        // Create tokio byte streams that rmcp expects
391        let (mcp_server_stream, mcp_client_stream) = tokio::io::duplex(8192);
392        let (mcp_server_read, mcp_server_write) = tokio::io::split(mcp_server_stream);
393        let (mcp_client_read, mcp_client_write) = tokio::io::split(mcp_client_stream);
394
395        let run_client = async {
396            // Connect byte_streams to the provided client
397            let byte_streams =
398                ByteStreams::new(mcp_client_write.compat_write(), mcp_client_read.compat());
399            drop(
400                <ByteStreams<_, _> as ConnectTo<role::mcp::Client>>::connect_to(
401                    byte_streams,
402                    client,
403                )
404                .await,
405            );
406            Ok(())
407        };
408
409        let run_server = async {
410            // Run the rmcp server with the server side of the duplex stream
411            let running_server = rmcp::ServiceExt::serve(self, (mcp_server_read, mcp_server_write))
412                .await
413                .map_err(crate::Error::into_internal_error)?;
414
415            // Wait for the server to finish
416            running_server
417                .waiting()
418                .await
419                .map(|_quit_reason| ())
420                .map_err(crate::Error::into_internal_error)
421        };
422
423        (run_client, run_server).try_join().await?;
424        Ok(())
425    }
426}
427
428impl<R: Role> ServerHandler for McpServerConnection<R> {
429    async fn call_tool(
430        &self,
431        request: rmcp::model::CallToolRequestParams,
432        context: rmcp::service::RequestContext<rmcp::RoleServer>,
433    ) -> Result<CallToolResult, ErrorData> {
434        // Lookup the tool definition, erroring if not found or disabled
435        let Some(registered) = self.data.tools.get(&request.name[..]) else {
436            return Err(rmcp::model::ErrorData::invalid_params(
437                format!("tool `{}` not found", request.name),
438                None,
439            ));
440        };
441
442        // Treat disabled tools as not found
443        if !self.data.enabled_tools.is_enabled(&request.name) {
444            return Err(rmcp::model::ErrorData::invalid_params(
445                format!("tool `{}` not found", request.name),
446                None,
447            ));
448        }
449
450        // Convert input into JSON
451        let serde_value = serde_json::to_value(request.arguments).expect("valid json");
452
453        // Execute the user's tool, unless cancellation occurs
454        let has_structured_output = registered.has_structured_output;
455        match futures::future::select(
456            registered
457                .tool
458                .call_tool(serde_value, self.mcp_connection.clone()),
459            pin!(context.ct.cancelled()),
460        )
461        .await
462        {
463            // If completed successfully
464            Either::Left((m, _)) => match m {
465                Ok(result) => {
466                    // Use structured output only if the tool declared an output_schema
467                    if has_structured_output {
468                        Ok(CallToolResult::structured(result))
469                    } else {
470                        Ok(CallToolResult::success(vec![rmcp::model::Content::text(
471                            result.to_string(),
472                        )]))
473                    }
474                }
475                Err(error) => Err(to_rmcp_error(error)),
476            },
477
478            // If cancelled
479            Either::Right(((), _)) => {
480                Err(rmcp::ErrorData::internal_error("operation cancelled", None))
481            }
482        }
483    }
484
485    async fn list_tools(
486        &self,
487        _request: Option<rmcp::model::PaginatedRequestParams>,
488        _context: rmcp::service::RequestContext<rmcp::RoleServer>,
489    ) -> Result<rmcp::model::ListToolsResult, ErrorData> {
490        // Return only enabled tools
491        let tools: Vec<_> = self
492            .data
493            .tool_models
494            .iter()
495            .filter(|t| self.data.enabled_tools.is_enabled(&t.name))
496            .cloned()
497            .collect();
498        Ok(ListToolsResult::with_all_items(tools))
499    }
500
501    fn get_info(&self) -> rmcp::model::ServerInfo {
502        // Basic server info
503        let base = rmcp::model::ServerInfo::new(
504            rmcp::model::ServerCapabilities::builder()
505                .enable_tools()
506                .build(),
507        )
508        .with_server_info(rmcp::model::Implementation::default())
509        .with_protocol_version(rmcp::model::ProtocolVersion::default());
510
511        if let Some(instr) = self.data.instructions.clone() {
512            base.with_instructions(instr)
513        } else {
514            base
515        }
516    }
517}
518
519/// Erased version of the MCP tool trait that is dyn-compatible.
520trait ErasedMcpTool<Counterpart: Role>: Send + Sync {
521    fn call_tool(
522        &self,
523        input: serde_json::Value,
524        connection: McpConnectionTo<Counterpart>,
525    ) -> BoxFuture<'_, Result<serde_json::Value, crate::Error>>;
526}
527
528/// Create an `rmcp` tool model from our [`McpTool`] trait.
529fn make_tool_model<R: Role, M: McpTool<R>>(tool: &M) -> Tool {
530    let mut tool = rmcp::model::Tool::new(
531        tool.name(),
532        tool.description(),
533        schema_for_type::<M::Input>(),
534    )
535    .with_execution(rmcp::model::ToolExecution::new());
536
537    if let Ok(schema) = schema_for_output::<M::Output>() {
538        // schema_for_output returns Err for non-object types (strings, integers, etc.)
539        // since MCP structured output requires JSON objects. We set
540        // output_schema to None for these tools, signaling unstructured output.
541        tool = tool.with_raw_output_schema(schema);
542    }
543
544    tool
545}
546
547/// Create a [`ErasedMcpTool`] from a [`McpTool`], erasing the type details.
548fn make_erased_mcp_tool<'s, R: Role, M: McpTool<R> + 's>(
549    tool: M,
550) -> Arc<dyn ErasedMcpTool<R> + 's> {
551    struct ErasedMcpToolImpl<M> {
552        tool: M,
553    }
554
555    impl<R, M> ErasedMcpTool<R> for ErasedMcpToolImpl<M>
556    where
557        R: Role,
558        M: McpTool<R>,
559    {
560        fn call_tool(
561            &self,
562            input: serde_json::Value,
563            context: McpConnectionTo<R>,
564        ) -> BoxFuture<'_, Result<serde_json::Value, crate::Error>> {
565            Box::pin(async move {
566                let input = serde_json::from_value(input).map_err(crate::util::internal_error)?;
567                serde_json::to_value(self.tool.call_tool(input, context).await?)
568                    .map_err(crate::util::internal_error)
569            })
570        }
571    }
572
573    Arc::new(ErasedMcpToolImpl { tool })
574}
575
576/// Convert a [`crate::Error`] into an [`rmcp::ErrorData`].
577fn to_rmcp_error(error: crate::Error) -> rmcp::ErrorData {
578    rmcp::ErrorData {
579        code: rmcp::model::ErrorCode(error.code.into()),
580        message: error.message.into(),
581        data: error.data,
582    }
583}
584
585/// MCP tool used for `tool_fn` and `tooL_fn_mut`.
586/// Each time it is invoked, it sends a `ToolCall`  message to `call_tx`.
587struct ToolFnTool<P, Ret, R: Role> {
588    name: String,
589    description: String,
590    call_tx: mpsc::Sender<ToolCall<P, Ret, R>>,
591}
592
593impl<P, Ret, R> McpTool<R> for ToolFnTool<P, Ret, R>
594where
595    R: Role,
596    P: JsonSchema + DeserializeOwned + 'static + Send,
597    Ret: JsonSchema + Serialize + 'static + Send,
598{
599    type Input = P;
600    type Output = Ret;
601
602    fn name(&self) -> String {
603        self.name.clone()
604    }
605
606    fn description(&self) -> String {
607        self.description.clone()
608    }
609
610    async fn call_tool(
611        &self,
612        params: P,
613        mcp_connection: McpConnectionTo<R>,
614    ) -> Result<Ret, crate::Error> {
615        let (result_tx, result_rx) = oneshot::channel();
616
617        self.call_tx
618            .clone()
619            .send(ToolCall {
620                params,
621                mcp_connection,
622                result_tx,
623            })
624            .await
625            .map_err(crate::util::internal_error)?;
626
627        result_rx.await.map_err(crate::util::internal_error)?
628    }
629}