llm_agent/mcp/
toolkit.rs

1use super::{
2    content::convert_mcp_content, MCPInit, MCPParams, MCPStdioParams, MCPStreamableHTTPParams,
3};
4use crate::{
5    errors::BoxedError,
6    tool::{AgentTool, AgentToolResult},
7    toolkit::{Toolkit, ToolkitSession},
8    RunState,
9};
10use futures::future::BoxFuture;
11use llm_sdk;
12use rmcp::{
13    handler::client::ClientHandler,
14    model::{CallToolRequestParam, CallToolResult, Tool},
15    service::{serve_client, NotificationContext, RoleClient, RunningService},
16    transport::{
17        child_process::TokioChildProcess,
18        streamable_http_client::{
19            StreamableHttpClientTransport, StreamableHttpClientTransportConfig,
20        },
21    },
22};
23use serde_json::Value;
24use std::{
25    io::{Error as IoError, ErrorKind},
26    sync::{Arc, OnceLock, RwLock, Weak},
27};
28use tokio::process::Command;
29
/// Shorthand for the rmcp client-role `RunningService` whose handler type is
/// this module's `MCPToolkitState`.
30type MCPRunningService<TCtx> = RunningService<RoleClient, MCPToolkitState<TCtx>>;
31
32/// Toolkit implementation backed by the Model Context Protocol.
33pub struct MCPToolkit<TCtx>
34where
35    TCtx: Send + Sync + 'static,
36{
37    init: MCPInit<TCtx>,
38}
39
40impl<TCtx> MCPToolkit<TCtx>
41where
42    TCtx: Send + Sync + 'static,
43{
44    pub fn new(init: impl Into<MCPInit<TCtx>>) -> Self {
45        Self { init: init.into() }
46    }
47}
48
49impl<TCtx> Toolkit<TCtx> for MCPToolkit<TCtx>
50where
51    TCtx: Send + Sync + 'static,
52{
53    fn create_session<'a>(
54        &'a self,
55        context: &'a TCtx,
56    ) -> BoxFuture<'a, Result<Box<dyn ToolkitSession<TCtx> + Send + Sync>, BoxedError>> {
57        Box::pin(async move {
58            let params = self.init.resolve(context).await?;
59            let session = MCPToolkitSession::new(params).await?;
60            let boxed: Box<dyn ToolkitSession<TCtx> + Send + Sync> = Box::new(session);
61            Ok(boxed)
62        })
63    }
64}
65
66/// `ToolkitSession` implementation that exposes MCP tools to the agent runtime.
67struct MCPToolkitSession<TCtx>
68where
69    TCtx: Send + Sync + 'static,
70{
71    service: Arc<MCPRunningService<TCtx>>,
72    state: MCPToolkitState<TCtx>,
73}
74
75impl<TCtx> MCPToolkitSession<TCtx>
76where
77    TCtx: Send + Sync + 'static,
78{
79    async fn new(params: MCPParams) -> Result<Self, BoxedError> {
80        let state = MCPToolkitState::new();
81        let handler = state.clone();
82
83        let service = match params {
84            MCPParams::Stdio(MCPStdioParams { command, args }) => {
85                let mut cmd = Command::new(command);
86                cmd.args(args);
87                let transport = TokioChildProcess::new(cmd)?;
88                serve_client(handler, transport).await?
89            }
90            MCPParams::StreamableHttp(MCPStreamableHTTPParams { url, authorization }) => {
91                let mut config = StreamableHttpClientTransportConfig::with_uri(url.clone());
92                if let Some(token) = authorization.as_deref() {
93                    config = config.auth_header(strip_bearer_prefix(token));
94                }
95                let transport = StreamableHttpClientTransport::from_config(config);
96                serve_client(handler, transport).await?
97            }
98        };
99
100        let service = Arc::new(service);
101        state.register_service(&service);
102        state.refresh_with(&service).await?;
103
104        Ok(Self { service, state })
105    }
106}
107
108impl<TCtx> ToolkitSession<TCtx> for MCPToolkitSession<TCtx>
109where
110    TCtx: Send + Sync + 'static,
111{
112    fn system_prompt(&self) -> Option<String> {
113        None
114    }
115
116    fn tools(&self) -> Vec<Arc<dyn AgentTool<TCtx>>> {
117        self.state.tools()
118    }
119
120    fn close(self: Box<Self>) -> BoxFuture<'static, Result<(), BoxedError>> {
121        Box::pin(async move {
122            match Arc::try_unwrap(self.service) {
123                Ok(service) => {
124                    let _ = service.cancel().await;
125                }
126                Err(arc) => {
127                    arc.cancellation_token().cancel();
128                }
129            }
130            Ok(())
131        })
132    }
133}
134
135struct MCPRemoteTool<TCtx>
136where
137    TCtx: Send + Sync + 'static,
138{
139    service: Weak<MCPRunningService<TCtx>>,
140    name: String,
141    description: String,
142    parameters: llm_sdk::JSONSchema,
143}
144
145impl<TCtx> MCPRemoteTool<TCtx>
146where
147    TCtx: Send + Sync + 'static,
148{
149    fn new(service: &Arc<MCPRunningService<TCtx>>, tool: Tool) -> Self {
150        let parameters =
151            Value::Object(Arc::try_unwrap(tool.input_schema).unwrap_or_else(|arc| (*arc).clone()));
152        let description = tool
153            .description
154            .map(std::borrow::Cow::into_owned)
155            .unwrap_or_default();
156        Self {
157            service: Arc::downgrade(service),
158            name: tool.name.into_owned(),
159            description,
160            parameters,
161        }
162    }
163}
164
165impl<TCtx> AgentTool<TCtx> for MCPRemoteTool<TCtx>
166where
167    TCtx: Send + Sync + 'static,
168{
169    fn name(&self) -> String {
170        self.name.clone()
171    }
172
173    fn description(&self) -> String {
174        self.description.clone()
175    }
176
177    fn parameters(&self) -> llm_sdk::JSONSchema {
178        self.parameters.clone()
179    }
180
181    fn execute(
182        &self,
183        args: Value,
184        _context: &TCtx,
185        _state: &RunState,
186    ) -> BoxFuture<'_, Result<AgentToolResult, BoxedError>> {
187        Box::pin(async move {
188            let arguments = match args {
189                Value::Null => None,
190                Value::Object(map) => Some(map),
191                other => {
192                    let message = format!("MCP tool arguments must be an object, received {other}");
193                    return Err(
194                        Box::new(IoError::new(ErrorKind::InvalidInput, message)) as BoxedError
195                    );
196                }
197            };
198
199            let request = CallToolRequestParam {
200                name: self.name.clone().into(),
201                arguments,
202            };
203
204            let Some(service) = self.service.upgrade() else {
205                return Err(Box::new(IoError::new(
206                    ErrorKind::NotConnected,
207                    "MCP service not initialised",
208                )) as BoxedError);
209            };
210            let result = service
211                .call_tool(request)
212                .await
213                .map_err(|err| Box::new(err) as BoxedError)?;
214
215            let CallToolResult {
216                content, is_error, ..
217            } = result;
218
219            let content = convert_mcp_content(content)?;
220            let is_error = is_error.unwrap_or(false);
221
222            Ok(AgentToolResult { content, is_error })
223        })
224    }
225}
226
/// Returns `token` without a leading `Bearer` scheme, because the rmcp
/// library already adds that prefix itself.
///
/// Surrounding whitespace is trimmed first. The scheme is matched
/// case-insensitively (`Bearer`, `bearer`, `BEARER`, ...) and must be
/// followed by whitespace; the whole whitespace run separating the scheme
/// from the credentials is dropped. Input without such a prefix is
/// returned trimmed but otherwise unchanged, so a bare `"Bearer"` with no
/// credentials is passed through as-is.
fn strip_bearer_prefix(token: &str) -> String {
    const SCHEME: &str = "bearer";

    let trimmed = token.trim();
    match trimmed.split_once(char::is_whitespace) {
        // HTTP auth schemes are case-insensitive, so compare ignoring case
        // to avoid sending a doubled "Bearer BEARER ..." header.
        Some((scheme, credentials)) if scheme.eq_ignore_ascii_case(SCHEME) => {
            credentials.trim_start().to_string()
        }
        _ => trimmed.to_string(),
    }
}
239
240#[allow(clippy::type_complexity)]
241struct MCPToolkitState<TCtx>
242where
243    TCtx: Send + Sync + 'static,
244{
245    service: Arc<OnceLock<Weak<MCPRunningService<TCtx>>>>,
246    tools: Arc<RwLock<Result<Vec<Arc<dyn AgentTool<TCtx>>>, String>>>,
247}
248
249impl<TCtx> Clone for MCPToolkitState<TCtx>
250where
251    TCtx: Send + Sync + 'static,
252{
253    fn clone(&self) -> Self {
254        Self {
255            service: Arc::clone(&self.service),
256            tools: Arc::clone(&self.tools),
257        }
258    }
259}
260
261impl<TCtx> MCPToolkitState<TCtx>
262where
263    TCtx: Send + Sync + 'static,
264{
265    fn new() -> Self {
266        Self {
267            service: Arc::new(OnceLock::new()),
268            tools: Arc::new(RwLock::new(Ok(Vec::new()))),
269        }
270    }
271
272    fn register_service(&self, service: &Arc<MCPRunningService<TCtx>>) {
273        let _ = self.service.set(Arc::downgrade(service));
274    }
275
276    async fn refresh(&self) -> Result<(), BoxedError> {
277        let service = self.service()?;
278        self.refresh_with(&service).await
279    }
280
281    async fn refresh_with(&self, service: &Arc<MCPRunningService<TCtx>>) -> Result<(), BoxedError> {
282        let specs = service
283            .peer()
284            .list_all_tools()
285            .await
286            .map_err(|err| Box::new(err) as BoxedError)?;
287
288        let mut new_tools: Vec<Arc<dyn AgentTool<TCtx>>> = Vec::with_capacity(specs.len());
289        for spec in specs {
290            let remote = MCPRemoteTool::new(service, spec);
291            new_tools.push(Arc::new(remote));
292        }
293
294        let mut guard = self.tools.write().expect("tool registry lock poisoned");
295        *guard = Ok(new_tools);
296        Ok(())
297    }
298
299    fn tools(&self) -> Vec<Arc<dyn AgentTool<TCtx>>> {
300        let guard = self.tools.read().expect("tool registry lock poisoned");
301        match guard.as_ref() {
302            Ok(tools) => tools.clone(),
303            Err(message) => panic!("mcp tool discovery failed: {message}"),
304        }
305    }
306
307    fn record_error<E>(&self, err: E)
308    where
309        E: std::fmt::Display,
310    {
311        if let Ok(mut guard) = self.tools.write() {
312            *guard = Err(err.to_string());
313        }
314    }
315
316    fn service(&self) -> Result<Arc<MCPRunningService<TCtx>>, BoxedError> {
317        self.service
318            .get()
319            .and_then(Weak::upgrade)
320            .ok_or_else(|| -> BoxedError {
321                Box::new(IoError::new(
322                    ErrorKind::NotConnected,
323                    "MCP service not initialised",
324                ))
325            })
326    }
327}
328
329impl<TCtx> ClientHandler for MCPToolkitState<TCtx>
330where
331    TCtx: Send + Sync + 'static,
332{
333    fn on_tool_list_changed(
334        &self,
335        _context: NotificationContext<RoleClient>,
336    ) -> impl std::future::Future<Output = ()> + Send + '_ {
337        let state = self.clone();
338        async move {
339            if let Err(err) = state.refresh().await {
340                state.record_error(err);
341            }
342        }
343    }
344}