1use super::{
2 content::convert_mcp_content, MCPInit, MCPParams, MCPStdioParams, MCPStreamableHTTPParams,
3};
4use crate::{
5 errors::BoxedError,
6 tool::{AgentTool, AgentToolResult},
7 toolkit::{Toolkit, ToolkitSession},
8 RunState,
9};
10use futures::future::BoxFuture;
11use llm_sdk;
12use rmcp::{
13 handler::client::ClientHandler,
14 model::{CallToolRequestParam, CallToolResult, Tool},
15 service::{serve_client, NotificationContext, RoleClient, RunningService},
16 transport::{
17 child_process::TokioChildProcess,
18 streamable_http_client::{
19 StreamableHttpClientTransport, StreamableHttpClientTransportConfig,
20 },
21 },
22};
23use serde_json::Value;
24use std::{
25 io::{Error as IoError, ErrorKind},
26 sync::{Arc, OnceLock, RwLock, Weak},
27};
28use tokio::process::Command;
29
/// The running rmcp client service used by this toolkit, with
/// [`MCPToolkitState`] installed as its client handler.
type MCPRunningService<TCtx> = RunningService<RoleClient, MCPToolkitState<TCtx>>;
31
32pub struct MCPToolkit<TCtx>
34where
35 TCtx: Send + Sync + 'static,
36{
37 init: MCPInit<TCtx>,
38}
39
40impl<TCtx> MCPToolkit<TCtx>
41where
42 TCtx: Send + Sync + 'static,
43{
44 pub fn new(init: impl Into<MCPInit<TCtx>>) -> Self {
45 Self { init: init.into() }
46 }
47}
48
49impl<TCtx> Toolkit<TCtx> for MCPToolkit<TCtx>
50where
51 TCtx: Send + Sync + 'static,
52{
53 fn create_session<'a>(
54 &'a self,
55 context: &'a TCtx,
56 ) -> BoxFuture<'a, Result<Box<dyn ToolkitSession<TCtx> + Send + Sync>, BoxedError>> {
57 Box::pin(async move {
58 let params = self.init.resolve(context).await?;
59 let session = MCPToolkitSession::new(params).await?;
60 let boxed: Box<dyn ToolkitSession<TCtx> + Send + Sync> = Box::new(session);
61 Ok(boxed)
62 })
63 }
64}
65
66struct MCPToolkitSession<TCtx>
68where
69 TCtx: Send + Sync + 'static,
70{
71 service: Arc<MCPRunningService<TCtx>>,
72 state: MCPToolkitState<TCtx>,
73}
74
75impl<TCtx> MCPToolkitSession<TCtx>
76where
77 TCtx: Send + Sync + 'static,
78{
79 async fn new(params: MCPParams) -> Result<Self, BoxedError> {
80 let state = MCPToolkitState::new();
81 let handler = state.clone();
82
83 let service = match params {
84 MCPParams::Stdio(MCPStdioParams { command, args }) => {
85 let mut cmd = Command::new(command);
86 cmd.args(args);
87 let transport = TokioChildProcess::new(cmd)?;
88 serve_client(handler, transport).await?
89 }
90 MCPParams::StreamableHttp(MCPStreamableHTTPParams { url, authorization }) => {
91 let mut config = StreamableHttpClientTransportConfig::with_uri(url.clone());
92 if let Some(token) = authorization.as_deref() {
93 config = config.auth_header(strip_bearer_prefix(token));
94 }
95 let transport = StreamableHttpClientTransport::from_config(config);
96 serve_client(handler, transport).await?
97 }
98 };
99
100 let service = Arc::new(service);
101 state.register_service(&service);
102 state.refresh_with(&service).await?;
103
104 Ok(Self { service, state })
105 }
106}
107
108impl<TCtx> ToolkitSession<TCtx> for MCPToolkitSession<TCtx>
109where
110 TCtx: Send + Sync + 'static,
111{
112 fn system_prompt(&self) -> Option<String> {
113 None
114 }
115
116 fn tools(&self) -> Vec<Arc<dyn AgentTool<TCtx>>> {
117 self.state.tools()
118 }
119
120 fn close(self: Box<Self>) -> BoxFuture<'static, Result<(), BoxedError>> {
121 Box::pin(async move {
122 match Arc::try_unwrap(self.service) {
123 Ok(service) => {
124 let _ = service.cancel().await;
125 }
126 Err(arc) => {
127 arc.cancellation_token().cancel();
128 }
129 }
130 Ok(())
131 })
132 }
133}
134
135struct MCPRemoteTool<TCtx>
136where
137 TCtx: Send + Sync + 'static,
138{
139 service: Weak<MCPRunningService<TCtx>>,
140 name: String,
141 description: String,
142 parameters: llm_sdk::JSONSchema,
143}
144
145impl<TCtx> MCPRemoteTool<TCtx>
146where
147 TCtx: Send + Sync + 'static,
148{
149 fn new(service: &Arc<MCPRunningService<TCtx>>, tool: Tool) -> Self {
150 let parameters =
151 Value::Object(Arc::try_unwrap(tool.input_schema).unwrap_or_else(|arc| (*arc).clone()));
152 let description = tool
153 .description
154 .map(std::borrow::Cow::into_owned)
155 .unwrap_or_default();
156 Self {
157 service: Arc::downgrade(service),
158 name: tool.name.into_owned(),
159 description,
160 parameters,
161 }
162 }
163}
164
165impl<TCtx> AgentTool<TCtx> for MCPRemoteTool<TCtx>
166where
167 TCtx: Send + Sync + 'static,
168{
169 fn name(&self) -> String {
170 self.name.clone()
171 }
172
173 fn description(&self) -> String {
174 self.description.clone()
175 }
176
177 fn parameters(&self) -> llm_sdk::JSONSchema {
178 self.parameters.clone()
179 }
180
181 fn execute(
182 &self,
183 args: Value,
184 _context: &TCtx,
185 _state: &RunState,
186 ) -> BoxFuture<'_, Result<AgentToolResult, BoxedError>> {
187 Box::pin(async move {
188 let arguments = match args {
189 Value::Null => None,
190 Value::Object(map) => Some(map),
191 other => {
192 let message = format!("MCP tool arguments must be an object, received {other}");
193 return Err(
194 Box::new(IoError::new(ErrorKind::InvalidInput, message)) as BoxedError
195 );
196 }
197 };
198
199 let request = CallToolRequestParam {
200 name: self.name.clone().into(),
201 arguments,
202 };
203
204 let Some(service) = self.service.upgrade() else {
205 return Err(Box::new(IoError::new(
206 ErrorKind::NotConnected,
207 "MCP service not initialised",
208 )) as BoxedError);
209 };
210 let result = service
211 .call_tool(request)
212 .await
213 .map_err(|err| Box::new(err) as BoxedError)?;
214
215 let CallToolResult {
216 content, is_error, ..
217 } = result;
218
219 let content = convert_mcp_content(content)?;
220 let is_error = is_error.unwrap_or(false);
221
222 Ok(AgentToolResult { content, is_error })
223 })
224 }
225}
226
/// Normalises a user-supplied authorization value to the bare token.
///
/// The value is trimmed first. If it then starts with the `Bearer` scheme
/// followed by whitespace, the scheme and that whitespace are removed. The
/// scheme is matched case-insensitively, because HTTP authentication scheme
/// names are case-insensitive (RFC 7235 §2.1). Any other value is returned
/// trimmed but otherwise unchanged, including a lone `"Bearer"` with no token.
///
/// The result is handed to `StreamableHttpClientTransportConfig::auth_header`.
/// NOTE(review): this presumes that transport adds the `Bearer` scheme
/// itself; confirm against the rmcp version in use.
fn strip_bearer_prefix(token: &str) -> String {
    const SCHEME: &str = "bearer";

    let trimmed = token.trim();
    let credentials = trimmed
        // `get` returns `None` when the slice would split a multi-byte char.
        .get(..SCHEME.len())
        .filter(|scheme| scheme.eq_ignore_ascii_case(SCHEME))
        // The slice above succeeded, so `SCHEME.len()` is a char boundary.
        .map(|_| &trimmed[SCHEME.len()..])
        // Require a separator so values like "Bearertoken" are left intact.
        .filter(|rest| rest.starts_with(char::is_whitespace))
        .map(str::trim_start);

    credentials.unwrap_or(trimmed).to_string()
}
239
240#[allow(clippy::type_complexity)]
241struct MCPToolkitState<TCtx>
242where
243 TCtx: Send + Sync + 'static,
244{
245 service: Arc<OnceLock<Weak<MCPRunningService<TCtx>>>>,
246 tools: Arc<RwLock<Result<Vec<Arc<dyn AgentTool<TCtx>>>, String>>>,
247}
248
249impl<TCtx> Clone for MCPToolkitState<TCtx>
250where
251 TCtx: Send + Sync + 'static,
252{
253 fn clone(&self) -> Self {
254 Self {
255 service: Arc::clone(&self.service),
256 tools: Arc::clone(&self.tools),
257 }
258 }
259}
260
261impl<TCtx> MCPToolkitState<TCtx>
262where
263 TCtx: Send + Sync + 'static,
264{
265 fn new() -> Self {
266 Self {
267 service: Arc::new(OnceLock::new()),
268 tools: Arc::new(RwLock::new(Ok(Vec::new()))),
269 }
270 }
271
272 fn register_service(&self, service: &Arc<MCPRunningService<TCtx>>) {
273 let _ = self.service.set(Arc::downgrade(service));
274 }
275
276 async fn refresh(&self) -> Result<(), BoxedError> {
277 let service = self.service()?;
278 self.refresh_with(&service).await
279 }
280
281 async fn refresh_with(&self, service: &Arc<MCPRunningService<TCtx>>) -> Result<(), BoxedError> {
282 let specs = service
283 .peer()
284 .list_all_tools()
285 .await
286 .map_err(|err| Box::new(err) as BoxedError)?;
287
288 let mut new_tools: Vec<Arc<dyn AgentTool<TCtx>>> = Vec::with_capacity(specs.len());
289 for spec in specs {
290 let remote = MCPRemoteTool::new(service, spec);
291 new_tools.push(Arc::new(remote));
292 }
293
294 let mut guard = self.tools.write().expect("tool registry lock poisoned");
295 *guard = Ok(new_tools);
296 Ok(())
297 }
298
299 fn tools(&self) -> Vec<Arc<dyn AgentTool<TCtx>>> {
300 let guard = self.tools.read().expect("tool registry lock poisoned");
301 match guard.as_ref() {
302 Ok(tools) => tools.clone(),
303 Err(message) => panic!("mcp tool discovery failed: {message}"),
304 }
305 }
306
307 fn record_error<E>(&self, err: E)
308 where
309 E: std::fmt::Display,
310 {
311 if let Ok(mut guard) = self.tools.write() {
312 *guard = Err(err.to_string());
313 }
314 }
315
316 fn service(&self) -> Result<Arc<MCPRunningService<TCtx>>, BoxedError> {
317 self.service
318 .get()
319 .and_then(Weak::upgrade)
320 .ok_or_else(|| -> BoxedError {
321 Box::new(IoError::new(
322 ErrorKind::NotConnected,
323 "MCP service not initialised",
324 ))
325 })
326 }
327}
328
329impl<TCtx> ClientHandler for MCPToolkitState<TCtx>
330where
331 TCtx: Send + Sync + 'static,
332{
333 fn on_tool_list_changed(
334 &self,
335 _context: NotificationContext<RoleClient>,
336 ) -> impl std::future::Future<Output = ()> + Send + '_ {
337 let state = self.clone();
338 async move {
339 if let Err(err) = state.refresh().await {
340 state.record_error(err);
341 }
342 }
343 }
344}