Skip to main content

agent_client_protocol_rmcp/
builder.rs

1//! MCP server builder for creating MCP servers.
2
3use std::{marker::PhantomData, pin::pin, sync::Arc};
4
5use futures::future::{BoxFuture, Either};
6use futures_concurrency::future::TryJoin;
7use rmcp::{
8    ErrorData, ServerHandler,
9    model::{CallToolResult, ListToolsResult, Tool},
10};
11use schemars::JsonSchema;
12use serde::{Serialize, de::DeserializeOwned};
13use tokio_util::compat::{TokioAsyncReadCompatExt, TokioAsyncWriteCompatExt};
14
15use agent_client_protocol as acp;
16use agent_client_protocol::{
17    ByteStreams, ChainRun, ConnectTo, DynConnectTo, NullRun, RunWithConnectionTo,
18    mcp_server::{
19        McpConnectionTo, McpServer, McpServerConnect, McpTool, McpToolMetadata, McpToolRegistry,
20    },
21    role::{self, Role},
22};
23
24/// Builder for creating MCP servers with tools.
25///
26/// Use [`crate::McpServerExt::builder`] to create a new builder, then chain methods to
27/// configure the server and call [`build`](Self::build) to create the server.
28///
29/// # Example
30///
31/// ```rust,ignore
32/// use agent_client_protocol::mcp_server::McpServer;
33/// use agent_client_protocol_rmcp::McpServerExt;
34///
35/// let server = McpServer::builder("my-server".to_string())
36///     .instructions("A helpful assistant")
37///     .tool(EchoTool)
38///     .tool_fn(
39///         "greet",
40///         "Greet someone by name",
41///         async |input: GreetInput, _cx| Ok(format!("Hello, {}!", input.name)),
42///         agent_client_protocol_rmcp::tool_fn!(),
43///     )
44///     .build();
45/// ```
46#[derive(Debug)]
47pub struct McpServerBuilder<Counterpart: Role, Responder>
48where
49    Responder: RunWithConnectionTo<Counterpart>,
50{
51    phantom: PhantomData<Counterpart>,
52    name: String,
53    data: McpToolRegistry<Counterpart>,
54    responder: Responder,
55}
56
57impl<Counterpart: Role> McpServerBuilder<Counterpart, NullRun> {
58    pub(super) fn new(name: String) -> Self {
59        Self {
60            name,
61            phantom: PhantomData,
62            data: McpToolRegistry::default(),
63            responder: NullRun,
64        }
65    }
66}
67
68impl<Counterpart: Role, Responder> McpServerBuilder<Counterpart, Responder>
69where
70    Responder: RunWithConnectionTo<Counterpart>,
71{
72    /// Set the server instructions that are provided to the client.
73    #[must_use]
74    pub fn instructions(mut self, instructions: impl ToString) -> Self {
75        self.data.set_instructions(instructions);
76        self
77    }
78
79    /// Add a tool to the server.
80    #[must_use]
81    pub fn tool(mut self, tool: impl McpTool<Counterpart> + 'static) -> Self {
82        self.data.register_tool(tool);
83        self
84    }
85
86    /// Disable all tools. After calling this, only tools explicitly enabled
87    /// with [`enable_tool`](Self::enable_tool) will be available.
88    #[must_use]
89    pub fn disable_all_tools(mut self) -> Self {
90        self.data.disable_all_tools();
91        self
92    }
93
94    /// Enable all tools. After calling this, all tools will be available
95    /// except those explicitly disabled with [`disable_tool`](Self::disable_tool).
96    #[must_use]
97    pub fn enable_all_tools(mut self) -> Self {
98        self.data.enable_all_tools();
99        self
100    }
101
102    /// Disable a specific tool by name.
103    ///
104    /// Returns an error if the tool is not registered.
105    pub fn disable_tool(mut self, name: &str) -> Result<Self, acp::Error> {
106        self.data.disable_tool(name)?;
107        Ok(self)
108    }
109
110    /// Enable a specific tool by name.
111    ///
112    /// Returns an error if the tool is not registered.
113    pub fn enable_tool(mut self, name: &str) -> Result<Self, acp::Error> {
114        self.data.enable_tool(name)?;
115        Ok(self)
116    }
117
118    /// Private fn: adds the tool but also adds a responder that will be
119    /// run while the MCP server is active.
120    fn tool_with_responder(
121        self,
122        tool: impl McpTool<Counterpart> + 'static,
123        tool_responder: impl RunWithConnectionTo<Counterpart>,
124    ) -> McpServerBuilder<Counterpart, impl RunWithConnectionTo<Counterpart>> {
125        let this = self.tool(tool);
126        McpServerBuilder {
127            phantom: PhantomData,
128            name: this.name,
129            data: this.data,
130            responder: ChainRun::new(this.responder, tool_responder),
131        }
132    }
133
134    /// Convenience wrapper for defining a "single-threaded" tool without having to create a struct.
135    /// By "single-threaded", we mean that only one invocation of the tool can be running at a time.
136    /// Typically agents invoke a tool once per session and then block waiting for the result,
137    /// so this is fine, but they could attempt to run multiple invocations concurrently, in which
138    /// case those invocations would be serialized.
139    ///
140    /// # Parameters
141    ///
142    /// * `name`: The name of the tool.
143    /// * `description`: The description of the tool.
144    /// * `func`: The function that implements the tool. Use an async closure like `async |args, cx| { .. }`.
145    ///
146    /// # Examples
147    ///
148    /// ```rust,ignore
149    /// McpServer::builder("my-server")
150    ///     .tool_fn_mut(
151    ///         "greet",
152    ///         "Greet someone by name",
153    ///         async |input: GreetInput, _cx| Ok(format!("Hello, {}!", input.name)),
154    ///     )
155    /// ```
156    pub fn tool_fn_mut<P, Ret, F>(
157        self,
158        name: impl ToString,
159        description: impl ToString,
160        func: F,
161        tool_future_hack: impl for<'a> Fn(
162            &'a mut F,
163            P,
164            McpConnectionTo<Counterpart>,
165        ) -> BoxFuture<'a, Result<Ret, acp::Error>>
166        + Send
167        + 'static,
168    ) -> McpServerBuilder<Counterpart, impl RunWithConnectionTo<Counterpart>>
169    where
170        P: JsonSchema + DeserializeOwned + 'static + Send,
171        Ret: JsonSchema + Serialize + 'static + Send,
172        F: AsyncFnMut(P, McpConnectionTo<Counterpart>) -> Result<Ret, acp::Error> + Send,
173    {
174        let (tool, responder) =
175            acp::mcp_server::tool_fn_mut(name, description, func, tool_future_hack);
176        self.tool_with_responder(tool, responder)
177    }
178
179    /// Convenience wrapper for defining a stateless tool that can run concurrently.
180    /// Unlike [`tool_fn_mut`](Self::tool_fn_mut), multiple invocations of this tool can run
181    /// at the same time since the function is `Fn` rather than `FnMut`.
182    ///
183    /// # Parameters
184    ///
185    /// * `name`: The name of the tool.
186    /// * `description`: The description of the tool.
187    /// * `func`: The function that implements the tool. Use an async closure like `async |args, cx| { .. }`.
188    ///
189    /// # Examples
190    ///
191    /// ```rust,ignore
192    /// McpServer::builder("my-server")
193    ///     .tool_fn(
194    ///         "greet",
195    ///         "Greet someone by name",
196    ///         async |input: GreetInput, _cx| Ok(format!("Hello, {}!", input.name)),
197    ///     )
198    /// ```
199    pub fn tool_fn<P, Ret, F>(
200        self,
201        name: impl ToString,
202        description: impl ToString,
203        func: F,
204        tool_future_hack: impl for<'a> Fn(
205            &'a F,
206            P,
207            McpConnectionTo<Counterpart>,
208        ) -> BoxFuture<'a, Result<Ret, acp::Error>>
209        + Send
210        + Sync
211        + 'static,
212    ) -> McpServerBuilder<Counterpart, impl RunWithConnectionTo<Counterpart>>
213    where
214        P: JsonSchema + DeserializeOwned + 'static + Send,
215        Ret: JsonSchema + Serialize + 'static + Send,
216        F: AsyncFn(P, McpConnectionTo<Counterpart>) -> Result<Ret, acp::Error>
217            + Send
218            + Sync
219            + 'static,
220    {
221        let (tool, responder) = acp::mcp_server::tool_fn(name, description, func, tool_future_hack);
222        self.tool_with_responder(tool, responder)
223    }
224
225    /// Create an MCP server from this builder.
226    ///
227    /// This builder can be attached to new sessions (see [`SessionBuilder::with_mcp_server`](`agent_client_protocol::SessionBuilder::with_mcp_server`))
228    /// or served up as part of a proxy (see [`Builder::with_mcp_server`](`agent_client_protocol::Builder::with_mcp_server`)).
229    pub fn build(self) -> McpServer<Counterpart, Responder> {
230        McpServer::new(
231            McpServerBuilt {
232                name: self.name,
233                data: Arc::new(self.data),
234            },
235            self.responder,
236        )
237    }
238}
239
240struct McpServerBuilt<Counterpart: Role> {
241    name: String,
242    data: Arc<McpToolRegistry<Counterpart>>,
243}
244
245impl<Counterpart: Role> McpServerConnect<Counterpart> for McpServerBuilt<Counterpart> {
246    fn name(&self) -> String {
247        self.name.clone()
248    }
249
250    fn connect(
251        &self,
252        mcp_connection: McpConnectionTo<Counterpart>,
253    ) -> DynConnectTo<role::mcp::Client> {
254        DynConnectTo::new(McpServerConnection {
255            data: self.data.clone(),
256            mcp_connection,
257        })
258    }
259}
260
261/// An MCP server instance connected to the ACP framework.
262pub(crate) struct McpServerConnection<Counterpart: Role> {
263    data: Arc<McpToolRegistry<Counterpart>>,
264    mcp_connection: McpConnectionTo<Counterpart>,
265}
266
267impl<Counterpart: Role> ConnectTo<role::mcp::Client> for McpServerConnection<Counterpart> {
268    async fn connect_to(self, client: impl ConnectTo<role::mcp::Server>) -> Result<(), acp::Error> {
269        // Create tokio byte streams that rmcp expects
270        let (mcp_server_stream, mcp_client_stream) = tokio::io::duplex(8192);
271        let (mcp_server_read, mcp_server_write) = tokio::io::split(mcp_server_stream);
272        let (mcp_client_read, mcp_client_write) = tokio::io::split(mcp_client_stream);
273
274        let run_client = async {
275            let byte_streams =
276                ByteStreams::new(mcp_client_write.compat_write(), mcp_client_read.compat());
277            <ByteStreams<_, _> as ConnectTo<role::mcp::Client>>::connect_to(byte_streams, client)
278                .await
279        };
280
281        let run_server = async {
282            // Run the rmcp server with the server side of the duplex stream
283            let running_server = rmcp::ServiceExt::serve(self, (mcp_server_read, mcp_server_write))
284                .await
285                .map_err(acp::Error::into_internal_error)?;
286
287            // Wait for the server to finish
288            running_server
289                .waiting()
290                .await
291                .map(|_quit_reason| ())
292                .map_err(acp::Error::into_internal_error)
293        };
294
295        (run_client, run_server).try_join().await?;
296        Ok(())
297    }
298}
299
300impl<R: Role> ServerHandler for McpServerConnection<R> {
301    async fn call_tool(
302        &self,
303        request: rmcp::model::CallToolRequestParams,
304        context: rmcp::service::RequestContext<rmcp::RoleServer>,
305    ) -> Result<CallToolResult, ErrorData> {
306        // Lookup the tool definition, erroring if not found or disabled
307        let Some(registered) = self.data.enabled_tool(&request.name) else {
308            return Err(rmcp::model::ErrorData::invalid_params(
309                format!("tool `{}` not found", request.name),
310                None,
311            ));
312        };
313
314        // Convert input into JSON
315        let serde_value = serde_json::to_value(request.arguments).expect("valid json");
316
317        // Execute the user's tool, unless cancellation occurs
318        let has_structured_output = registered.has_structured_output();
319        match futures::future::select(
320            registered.call_tool(serde_value, self.mcp_connection.clone()),
321            pin!(context.ct.cancelled()),
322        )
323        .await
324        {
325            // If completed successfully
326            Either::Left((m, _)) => match m {
327                Ok(result) => {
328                    // Use structured output only if the tool declared an output_schema
329                    if has_structured_output {
330                        Ok(CallToolResult::structured(result))
331                    } else {
332                        Ok(CallToolResult::success(vec![rmcp::model::Content::text(
333                            result.to_string(),
334                        )]))
335                    }
336                }
337                Err(error) => Err(to_rmcp_error(error)),
338            },
339
340            // If cancelled
341            Either::Right(((), _)) => {
342                Err(rmcp::ErrorData::internal_error("operation cancelled", None))
343            }
344        }
345    }
346
347    async fn list_tools(
348        &self,
349        _request: Option<rmcp::model::PaginatedRequestParams>,
350        _context: rmcp::service::RequestContext<rmcp::RoleServer>,
351    ) -> Result<rmcp::model::ListToolsResult, ErrorData> {
352        // Return only enabled tools
353        let tools: Vec<_> = self
354            .data
355            .enabled_tools()
356            .map(|tool| make_tool_model(tool.metadata()))
357            .collect();
358        Ok(ListToolsResult::with_all_items(tools))
359    }
360
361    fn get_info(&self) -> rmcp::model::ServerInfo {
362        // Basic server info
363        let base = rmcp::model::ServerInfo::new(
364            rmcp::model::ServerCapabilities::builder()
365                .enable_tools()
366                .build(),
367        )
368        .with_server_info(rmcp::model::Implementation::default())
369        .with_protocol_version(rmcp::model::ProtocolVersion::default());
370
371        if let Some(instructions) = self.data.instructions() {
372            base.with_instructions(instructions.to_string())
373        } else {
374            base
375        }
376    }
377}
378
379/// Create an `rmcp` tool model from runtime-neutral MCP tool metadata.
380fn make_tool_model(metadata: &McpToolMetadata) -> Tool {
381    let mut tool = rmcp::model::Tool::new(
382        metadata.name().to_string(),
383        metadata.description().to_string(),
384        metadata.input_schema().clone(),
385    )
386    .with_execution(rmcp::model::ToolExecution::new());
387
388    if let Some(title) = metadata.title() {
389        tool = tool.with_title(title.to_string());
390    }
391
392    if let Some(schema) = metadata.output_schema() {
393        tool = tool.with_raw_output_schema(schema.clone());
394    }
395
396    tool
397}
398
399/// Convert an [`agent_client_protocol::Error`] into an [`rmcp::ErrorData`].
400fn to_rmcp_error(error: acp::Error) -> rmcp::ErrorData {
401    rmcp::ErrorData {
402        code: rmcp::model::ErrorCode(error.code.into()),
403        message: error.message.into(),
404        data: error.data,
405    }
406}