agent_client_protocol_rmcp/
builder.rs1use std::{marker::PhantomData, pin::pin, sync::Arc};
4
5use futures::future::{BoxFuture, Either};
6use futures_concurrency::future::TryJoin;
7use rmcp::{
8 ErrorData, ServerHandler,
9 model::{CallToolResult, ListToolsResult, Tool},
10};
11use schemars::JsonSchema;
12use serde::{Serialize, de::DeserializeOwned};
13use tokio_util::compat::{TokioAsyncReadCompatExt, TokioAsyncWriteCompatExt};
14
15use agent_client_protocol as acp;
16use agent_client_protocol::{
17 ByteStreams, ChainRun, ConnectTo, DynConnectTo, NullRun, RunWithConnectionTo,
18 mcp_server::{
19 McpConnectionTo, McpServer, McpServerConnect, McpTool, McpToolMetadata, McpToolRegistry,
20 },
21 role::{self, Role},
22};
23
24#[derive(Debug)]
47pub struct McpServerBuilder<Counterpart: Role, Responder>
48where
49 Responder: RunWithConnectionTo<Counterpart>,
50{
51 phantom: PhantomData<Counterpart>,
52 name: String,
53 data: McpToolRegistry<Counterpart>,
54 responder: Responder,
55}
56
57impl<Counterpart: Role> McpServerBuilder<Counterpart, NullRun> {
58 pub(super) fn new(name: String) -> Self {
59 Self {
60 name,
61 phantom: PhantomData,
62 data: McpToolRegistry::default(),
63 responder: NullRun,
64 }
65 }
66}
67
68impl<Counterpart: Role, Responder> McpServerBuilder<Counterpart, Responder>
69where
70 Responder: RunWithConnectionTo<Counterpart>,
71{
72 #[must_use]
74 pub fn instructions(mut self, instructions: impl ToString) -> Self {
75 self.data.set_instructions(instructions);
76 self
77 }
78
79 #[must_use]
81 pub fn tool(mut self, tool: impl McpTool<Counterpart> + 'static) -> Self {
82 self.data.register_tool(tool);
83 self
84 }
85
86 #[must_use]
89 pub fn disable_all_tools(mut self) -> Self {
90 self.data.disable_all_tools();
91 self
92 }
93
94 #[must_use]
97 pub fn enable_all_tools(mut self) -> Self {
98 self.data.enable_all_tools();
99 self
100 }
101
102 pub fn disable_tool(mut self, name: &str) -> Result<Self, acp::Error> {
106 self.data.disable_tool(name)?;
107 Ok(self)
108 }
109
110 pub fn enable_tool(mut self, name: &str) -> Result<Self, acp::Error> {
114 self.data.enable_tool(name)?;
115 Ok(self)
116 }
117
118 fn tool_with_responder(
121 self,
122 tool: impl McpTool<Counterpart> + 'static,
123 tool_responder: impl RunWithConnectionTo<Counterpart>,
124 ) -> McpServerBuilder<Counterpart, impl RunWithConnectionTo<Counterpart>> {
125 let this = self.tool(tool);
126 McpServerBuilder {
127 phantom: PhantomData,
128 name: this.name,
129 data: this.data,
130 responder: ChainRun::new(this.responder, tool_responder),
131 }
132 }
133
134 pub fn tool_fn_mut<P, Ret, F>(
157 self,
158 name: impl ToString,
159 description: impl ToString,
160 func: F,
161 tool_future_hack: impl for<'a> Fn(
162 &'a mut F,
163 P,
164 McpConnectionTo<Counterpart>,
165 ) -> BoxFuture<'a, Result<Ret, acp::Error>>
166 + Send
167 + 'static,
168 ) -> McpServerBuilder<Counterpart, impl RunWithConnectionTo<Counterpart>>
169 where
170 P: JsonSchema + DeserializeOwned + 'static + Send,
171 Ret: JsonSchema + Serialize + 'static + Send,
172 F: AsyncFnMut(P, McpConnectionTo<Counterpart>) -> Result<Ret, acp::Error> + Send,
173 {
174 let (tool, responder) =
175 acp::mcp_server::tool_fn_mut(name, description, func, tool_future_hack);
176 self.tool_with_responder(tool, responder)
177 }
178
179 pub fn tool_fn<P, Ret, F>(
200 self,
201 name: impl ToString,
202 description: impl ToString,
203 func: F,
204 tool_future_hack: impl for<'a> Fn(
205 &'a F,
206 P,
207 McpConnectionTo<Counterpart>,
208 ) -> BoxFuture<'a, Result<Ret, acp::Error>>
209 + Send
210 + Sync
211 + 'static,
212 ) -> McpServerBuilder<Counterpart, impl RunWithConnectionTo<Counterpart>>
213 where
214 P: JsonSchema + DeserializeOwned + 'static + Send,
215 Ret: JsonSchema + Serialize + 'static + Send,
216 F: AsyncFn(P, McpConnectionTo<Counterpart>) -> Result<Ret, acp::Error>
217 + Send
218 + Sync
219 + 'static,
220 {
221 let (tool, responder) = acp::mcp_server::tool_fn(name, description, func, tool_future_hack);
222 self.tool_with_responder(tool, responder)
223 }
224
225 pub fn build(self) -> McpServer<Counterpart, Responder> {
230 McpServer::new(
231 McpServerBuilt {
232 name: self.name,
233 data: Arc::new(self.data),
234 },
235 self.responder,
236 )
237 }
238}
239
240struct McpServerBuilt<Counterpart: Role> {
241 name: String,
242 data: Arc<McpToolRegistry<Counterpart>>,
243}
244
245impl<Counterpart: Role> McpServerConnect<Counterpart> for McpServerBuilt<Counterpart> {
246 fn name(&self) -> String {
247 self.name.clone()
248 }
249
250 fn connect(
251 &self,
252 mcp_connection: McpConnectionTo<Counterpart>,
253 ) -> DynConnectTo<role::mcp::Client> {
254 DynConnectTo::new(McpServerConnection {
255 data: self.data.clone(),
256 mcp_connection,
257 })
258 }
259}
260
261pub(crate) struct McpServerConnection<Counterpart: Role> {
263 data: Arc<McpToolRegistry<Counterpart>>,
264 mcp_connection: McpConnectionTo<Counterpart>,
265}
266
267impl<Counterpart: Role> ConnectTo<role::mcp::Client> for McpServerConnection<Counterpart> {
268 async fn connect_to(self, client: impl ConnectTo<role::mcp::Server>) -> Result<(), acp::Error> {
269 let (mcp_server_stream, mcp_client_stream) = tokio::io::duplex(8192);
271 let (mcp_server_read, mcp_server_write) = tokio::io::split(mcp_server_stream);
272 let (mcp_client_read, mcp_client_write) = tokio::io::split(mcp_client_stream);
273
274 let run_client = async {
275 let byte_streams =
276 ByteStreams::new(mcp_client_write.compat_write(), mcp_client_read.compat());
277 <ByteStreams<_, _> as ConnectTo<role::mcp::Client>>::connect_to(byte_streams, client)
278 .await
279 };
280
281 let run_server = async {
282 let running_server = rmcp::ServiceExt::serve(self, (mcp_server_read, mcp_server_write))
284 .await
285 .map_err(acp::Error::into_internal_error)?;
286
287 running_server
289 .waiting()
290 .await
291 .map(|_quit_reason| ())
292 .map_err(acp::Error::into_internal_error)
293 };
294
295 (run_client, run_server).try_join().await?;
296 Ok(())
297 }
298}
299
300impl<R: Role> ServerHandler for McpServerConnection<R> {
301 async fn call_tool(
302 &self,
303 request: rmcp::model::CallToolRequestParams,
304 context: rmcp::service::RequestContext<rmcp::RoleServer>,
305 ) -> Result<CallToolResult, ErrorData> {
306 let Some(registered) = self.data.enabled_tool(&request.name) else {
308 return Err(rmcp::model::ErrorData::invalid_params(
309 format!("tool `{}` not found", request.name),
310 None,
311 ));
312 };
313
314 let serde_value = serde_json::to_value(request.arguments).expect("valid json");
316
317 let has_structured_output = registered.has_structured_output();
319 match futures::future::select(
320 registered.call_tool(serde_value, self.mcp_connection.clone()),
321 pin!(context.ct.cancelled()),
322 )
323 .await
324 {
325 Either::Left((m, _)) => match m {
327 Ok(result) => {
328 if has_structured_output {
330 Ok(CallToolResult::structured(result))
331 } else {
332 Ok(CallToolResult::success(vec![rmcp::model::Content::text(
333 result.to_string(),
334 )]))
335 }
336 }
337 Err(error) => Err(to_rmcp_error(error)),
338 },
339
340 Either::Right(((), _)) => {
342 Err(rmcp::ErrorData::internal_error("operation cancelled", None))
343 }
344 }
345 }
346
347 async fn list_tools(
348 &self,
349 _request: Option<rmcp::model::PaginatedRequestParams>,
350 _context: rmcp::service::RequestContext<rmcp::RoleServer>,
351 ) -> Result<rmcp::model::ListToolsResult, ErrorData> {
352 let tools: Vec<_> = self
354 .data
355 .enabled_tools()
356 .map(|tool| make_tool_model(tool.metadata()))
357 .collect();
358 Ok(ListToolsResult::with_all_items(tools))
359 }
360
361 fn get_info(&self) -> rmcp::model::ServerInfo {
362 let base = rmcp::model::ServerInfo::new(
364 rmcp::model::ServerCapabilities::builder()
365 .enable_tools()
366 .build(),
367 )
368 .with_server_info(rmcp::model::Implementation::default())
369 .with_protocol_version(rmcp::model::ProtocolVersion::default());
370
371 if let Some(instructions) = self.data.instructions() {
372 base.with_instructions(instructions.to_string())
373 } else {
374 base
375 }
376 }
377}
378
379fn make_tool_model(metadata: &McpToolMetadata) -> Tool {
381 let mut tool = rmcp::model::Tool::new(
382 metadata.name().to_string(),
383 metadata.description().to_string(),
384 metadata.input_schema().clone(),
385 )
386 .with_execution(rmcp::model::ToolExecution::new());
387
388 if let Some(title) = metadata.title() {
389 tool = tool.with_title(title.to_string());
390 }
391
392 if let Some(schema) = metadata.output_schema() {
393 tool = tool.with_raw_output_schema(schema.clone());
394 }
395
396 tool
397}
398
399fn to_rmcp_error(error: acp::Error) -> rmcp::ErrorData {
401 rmcp::ErrorData {
402 code: rmcp::model::ErrorCode(error.code.into()),
403 message: error.message.into(),
404 data: error.data,
405 }
406}