1use adk_core::{AdkError, ReadonlyContext, Result, Tool, ToolContext, Toolset};
10use async_trait::async_trait;
11use rmcp::{
12 model::{CallToolRequestParam, RawContent, ResourceContents},
13 service::RunningService,
14 RoleClient,
15};
16use serde_json::{json, Value};
17use std::ops::Deref;
18use std::sync::Arc;
19use tokio::sync::Mutex;
20
/// Predicate over MCP tool names: return `true` to expose a tool, `false` to hide it.
///
/// Kept behind an [`Arc`] so it can be shared cheaply, and bounded by `Send + Sync`
/// so it can be shared between threads.
pub type ToolFilter = Arc<dyn Fn(&str) -> bool + Sync + Send>;
23
24pub struct McpToolset<S = ()>
56where
57 S: rmcp::service::Service<RoleClient> + Send + Sync + 'static,
58{
59 client: Arc<Mutex<RunningService<RoleClient, S>>>,
61 tool_filter: Option<ToolFilter>,
63 name: String,
65}
66
67impl<S> McpToolset<S>
68where
69 S: rmcp::service::Service<RoleClient> + Send + Sync + 'static,
70{
71 pub fn new(client: RunningService<RoleClient, S>) -> Self {
89 Self {
90 client: Arc::new(Mutex::new(client)),
91 tool_filter: None,
92 name: "mcp_toolset".to_string(),
93 }
94 }
95
96 pub fn with_name(mut self, name: impl Into<String>) -> Self {
98 self.name = name.into();
99 self
100 }
101
102 pub fn with_filter<F>(mut self, filter: F) -> Self
116 where
117 F: Fn(&str) -> bool + Send + Sync + 'static,
118 {
119 self.tool_filter = Some(Arc::new(filter));
120 self
121 }
122
123 pub fn with_tools(self, tool_names: &[&str]) -> Self {
132 let names: Vec<String> = tool_names.iter().map(|s| s.to_string()).collect();
133 self.with_filter(move |name| names.iter().any(|n| n == name))
134 }
135}
136
137#[async_trait]
138impl<S> Toolset for McpToolset<S>
139where
140 S: rmcp::service::Service<RoleClient> + Send + Sync + 'static,
141{
142 fn name(&self) -> &str {
143 &self.name
144 }
145
146 async fn tools(&self, _ctx: Arc<dyn ReadonlyContext>) -> Result<Vec<Arc<dyn Tool>>> {
147 let client = self.client.lock().await;
148
149 let mcp_tools = client
151 .list_all_tools()
152 .await
153 .map_err(|e| AdkError::Tool(format!("Failed to list MCP tools: {}", e)))?;
154
155 let mut tools: Vec<Arc<dyn Tool>> = Vec::new();
157
158 for mcp_tool in mcp_tools {
159 let tool_name = mcp_tool.name.to_string();
160
161 if let Some(ref filter) = self.tool_filter {
163 if !filter(&tool_name) {
164 continue;
165 }
166 }
167
168 let adk_tool = McpTool {
169 name: tool_name,
170 description: mcp_tool.description.map(|d| d.to_string()).unwrap_or_default(),
171 input_schema: Some(Value::Object(mcp_tool.input_schema.as_ref().clone())),
172 output_schema: mcp_tool.output_schema.map(|s| Value::Object(s.as_ref().clone())),
173 client: self.client.clone(),
174 };
175
176 tools.push(Arc::new(adk_tool) as Arc<dyn Tool>);
177 }
178
179 Ok(tools)
180 }
181}
182
183struct McpTool<S>
187where
188 S: rmcp::service::Service<RoleClient> + Send + Sync + 'static,
189{
190 name: String,
191 description: String,
192 input_schema: Option<Value>,
193 output_schema: Option<Value>,
194 client: Arc<Mutex<RunningService<RoleClient, S>>>,
195}
196
197#[async_trait]
198impl<S> Tool for McpTool<S>
199where
200 S: rmcp::service::Service<RoleClient> + Send + Sync + 'static,
201{
202 fn name(&self) -> &str {
203 &self.name
204 }
205
206 fn description(&self) -> &str {
207 &self.description
208 }
209
210 fn is_long_running(&self) -> bool {
211 false
212 }
213
214 fn parameters_schema(&self) -> Option<Value> {
215 self.input_schema.clone()
216 }
217
218 fn response_schema(&self) -> Option<Value> {
219 self.output_schema.clone()
220 }
221
222 async fn execute(&self, _ctx: Arc<dyn ToolContext>, args: Value) -> Result<Value> {
223 let client = self.client.lock().await;
224
225 let result = client
227 .call_tool(CallToolRequestParam {
228 name: self.name.clone().into(),
229 arguments: if args.is_null() || args == json!({}) {
230 None
231 } else {
232 match args {
234 Value::Object(map) => Some(map),
235 _ => {
236 return Err(AdkError::Tool(
237 "Tool arguments must be an object".to_string(),
238 ))
239 }
240 }
241 },
242 })
243 .await
244 .map_err(|e| {
245 AdkError::Tool(format!("Failed to call MCP tool '{}': {}", self.name, e))
246 })?;
247
248 if result.is_error.unwrap_or(false) {
250 let mut error_msg = format!("MCP tool '{}' execution failed", self.name);
251
252 for content in &result.content {
254 if let Some(text_content) = content.deref().as_text() {
256 error_msg.push_str(": ");
257 error_msg.push_str(&text_content.text);
258 break;
259 }
260 }
261
262 return Err(AdkError::Tool(error_msg));
263 }
264
265 if let Some(structured) = result.structured_content {
267 return Ok(json!({ "output": structured }));
268 }
269
270 let mut text_parts: Vec<String> = Vec::new();
272
273 for content in &result.content {
274 let raw: &RawContent = content.deref();
276 match raw {
277 RawContent::Text(text_content) => {
278 text_parts.push(text_content.text.clone());
279 }
280 RawContent::Image(image_content) => {
281 text_parts.push(format!(
283 "[Image: {} bytes, mime: {}]",
284 image_content.data.len(),
285 image_content.mime_type
286 ));
287 }
288 RawContent::Resource(resource_content) => {
289 let uri = match &resource_content.resource {
290 ResourceContents::TextResourceContents { uri, .. } => uri,
291 ResourceContents::BlobResourceContents { uri, .. } => uri,
292 };
293 text_parts.push(format!("[Resource: {}]", uri));
294 }
295 RawContent::Audio(_) => {
296 text_parts.push("[Audio content]".to_string());
297 }
298 RawContent::ResourceLink(link) => {
299 text_parts.push(format!("[ResourceLink: {}]", link.uri));
300 }
301 }
302 }
303
304 if text_parts.is_empty() {
305 return Err(AdkError::Tool(format!("MCP tool '{}' returned no content", self.name)));
306 }
307
308 Ok(json!({ "output": text_parts.join("\n") }))
309 }
310}
311
312unsafe impl<S> Send for McpTool<S> where
314 S: rmcp::service::Service<RoleClient> + Send + Sync + 'static
315{
316}
317unsafe impl<S> Sync for McpTool<S> where
318 S: rmcp::service::Service<RoleClient> + Send + Sync + 'static
319{
320}