1use std::sync::Arc;
7
8use futures::StreamExt;
9use rmcp::{
10 ErrorData as McpError,
11 ServerHandler,
12 model::*,
13 service::{RequestContext, RoleServer},
14};
15use serde_json::json;
16
17use crate::plexus::{DynamicHub, PlexusError, PluginSchema};
18use crate::plexus::types::PlexusStreamItem;
19
20fn schemas_to_rmcp_tools(schemas: Vec<PluginSchema>) -> Vec<Tool> {
29 schemas
30 .into_iter()
31 .flat_map(|activation| {
32 let namespace = activation.namespace.clone();
33 activation.methods.into_iter().map(move |method| {
34 let name = format!("{}.{}", namespace, method.name);
35 let description = method.description.clone();
36
37 let input_schema = method
39 .params
40 .and_then(|s| serde_json::to_value(s).ok())
41 .and_then(|v| v.as_object().cloned())
42 .map(|mut obj| {
43 if !obj.contains_key("type") {
45 obj.insert("type".to_string(), json!("object"));
46 }
47 Arc::new(obj)
48 })
49 .unwrap_or_else(|| {
50 Arc::new(serde_json::Map::from_iter([
52 ("type".to_string(), json!("object")),
53 ]))
54 });
55
56 Tool::new(name, description, input_schema)
57 })
58 })
59 .collect()
60}
61
62fn plexus_to_mcp_error(e: PlexusError) -> McpError {
68 match e {
69 PlexusError::ActivationNotFound(name) => {
70 McpError::invalid_params(format!("Unknown activation: {}", name), None)
71 }
72 PlexusError::MethodNotFound { activation, method } => {
73 McpError::invalid_params(format!("Unknown method: {}.{}", activation, method), None)
74 }
75 PlexusError::InvalidParams(reason) => McpError::invalid_params(reason, None),
76 PlexusError::ExecutionError(error) => McpError::internal_error(error, None),
77 PlexusError::HandleNotSupported(activation) => {
78 McpError::invalid_params(format!("Handle resolution not supported: {}", activation), None)
79 }
80 PlexusError::TransportError(kind) => {
81 McpError::internal_error(format!("Transport error: {:?}", kind), None)
82 }
83 PlexusError::Unauthenticated(reason) => {
84 McpError::invalid_params(format!("Unauthenticated: {}", reason), None)
85 }
86 }
87}
88
89#[derive(Clone)]
95pub struct PlexusMcpBridge {
96 hub: Arc<DynamicHub>,
97}
98
99impl PlexusMcpBridge {
100 pub fn new(hub: Arc<DynamicHub>) -> Self {
101 Self { hub }
102 }
103}
104
105impl ServerHandler for PlexusMcpBridge {
106 fn get_info(&self) -> ServerInfo {
107 ServerInfo {
108 protocol_version: ProtocolVersion::LATEST,
109 capabilities: ServerCapabilities::builder()
110 .enable_tools()
111 .enable_logging()
112 .build(),
113 server_info: Implementation::from_build_env(),
114 instructions: Some(
115 "Plexus MCP server - provides access to all registered activations.".into(),
116 ),
117 }
118 }
119
120 async fn list_tools(
121 &self,
122 _request: Option<PaginatedRequestParam>,
123 _ctx: RequestContext<RoleServer>,
124 ) -> Result<ListToolsResult, McpError> {
125 let schemas = self.hub.list_plugin_schemas();
126 let tools = schemas_to_rmcp_tools(schemas);
127
128 tracing::debug!("Listing {} tools", tools.len());
129
130 Ok(ListToolsResult {
131 tools,
132 next_cursor: None,
133 meta: None,
134 })
135 }
136
137 async fn call_tool(
138 &self,
139 request: CallToolRequestParam,
140 ctx: RequestContext<RoleServer>,
141 ) -> Result<CallToolResult, McpError> {
142 let method_name = &request.name;
143 let arguments = request
144 .arguments
145 .map(serde_json::Value::Object)
146 .unwrap_or(json!({}));
147
148 tracing::debug!("Calling tool: {} with args: {:?}", method_name, arguments);
149
150 let progress_token = ctx.meta.get_progress_token();
152
153 let logger = format!("plexus.{}", method_name);
155
156 let stream = self
158 .hub
159 .route(method_name, arguments, None)
160 .await
161 .map_err(plexus_to_mcp_error)?;
162
163 let mut had_error = false;
165 let mut buffered_data: Vec<serde_json::Value> = Vec::new();
166 let mut error_messages: Vec<String> = Vec::new();
167
168 tokio::pin!(stream);
169 while let Some(item) = stream.next().await {
170 if ctx.ct.is_cancelled() {
172 return Err(McpError::internal_error("Cancelled", None));
173 }
174
175 match &item {
176 PlexusStreamItem::Progress {
177 message,
178 percentage,
179 ..
180 } => {
181 if let Some(ref token) = progress_token {
183 let _ = ctx
184 .peer
185 .notify_progress(ProgressNotificationParam {
186 progress_token: token.clone(),
187 progress: percentage.unwrap_or(0.0) as f64,
188 total: None,
189 message: Some(message.clone()),
190 })
191 .await;
192 }
193 }
194
195 PlexusStreamItem::Data {
196 content, content_type, ..
197 } => {
198 buffered_data.push(content.clone());
200
201 let _ = ctx
203 .peer
204 .notify_logging_message(LoggingMessageNotificationParam {
205 level: LoggingLevel::Info,
206 logger: Some(logger.clone()),
207 data: json!({
208 "type": "data",
209 "content_type": content_type,
210 "data": content,
211 }),
212 })
213 .await;
214 }
215
216 PlexusStreamItem::Error {
217 message, recoverable, ..
218 } => {
219 error_messages.push(message.clone());
221
222 let _ = ctx
223 .peer
224 .notify_logging_message(LoggingMessageNotificationParam {
225 level: LoggingLevel::Error,
226 logger: Some(logger.clone()),
227 data: json!({
228 "type": "error",
229 "error": message,
230 "recoverable": recoverable,
231 }),
232 })
233 .await;
234
235 if !recoverable {
236 had_error = true;
237 }
238 }
239
240 PlexusStreamItem::Done { .. } => {
241 break;
242 }
243
244 PlexusStreamItem::Request {
245 request_id,
246 request_data,
247 timeout_ms,
248 } => {
249 let _ = ctx
252 .peer
253 .notify_logging_message(LoggingMessageNotificationParam {
254 level: LoggingLevel::Info,
255 logger: Some(logger.clone()),
256 data: json!({
257 "type": "request",
258 "request_id": request_id,
259 "request_data": request_data,
260 "timeout_ms": timeout_ms,
261 }),
262 })
263 .await;
264 }
265 }
266 }
267
268 if had_error {
270 let error_content = if error_messages.is_empty() {
271 "Stream completed with errors".to_string()
272 } else {
273 error_messages.join("\n")
274 };
275 Ok(CallToolResult::error(vec![Content::text(error_content)]))
276 } else {
277 let text_content = if buffered_data.is_empty() {
279 "(no output)".to_string()
280 } else if buffered_data.len() == 1 {
281 match &buffered_data[0] {
283 serde_json::Value::String(s) => s.clone(),
284 other => serde_json::to_string_pretty(other).unwrap_or_default(),
285 }
286 } else {
287 let all_strings = buffered_data.iter().all(|v| v.is_string());
289 if all_strings {
290 buffered_data
291 .iter()
292 .filter_map(|v| v.as_str())
293 .collect::<Vec<_>>()
294 .join("")
295 } else {
296 serde_json::to_string_pretty(&buffered_data).unwrap_or_default()
297 }
298 };
299
300 Ok(CallToolResult::success(vec![Content::text(text_content)]))
301 }
302 }
303}