1use std::sync::Arc;
7
8use futures::StreamExt;
9use rmcp::{
10 ErrorData as McpError,
11 ServerHandler,
12 model::*,
13 service::{RequestContext, RoleServer},
14};
15use serde_json::json;
16
17use crate::plexus::bidirectional::{handle_pending_response, BidirError};
18use crate::plexus::types::PlexusStreamItem;
19use crate::plexus::{DynamicHub, PlexusError, PluginSchema};
20
21fn schemas_to_rmcp_tools(schemas: Vec<PluginSchema>) -> Vec<Tool> {
30 let mut tools: Vec<Tool> = schemas
31 .into_iter()
32 .flat_map(|activation| {
33 let namespace = activation.namespace.clone();
34 activation.methods.into_iter().map(move |method| {
35 let name = format!("{}.{}", namespace, method.name);
36 let description = method.description.clone();
37
38 let input_schema = method
40 .params
41 .and_then(|s| serde_json::to_value(s).ok())
42 .and_then(|v| v.as_object().cloned())
43 .map(|mut obj| {
44 if !obj.contains_key("type") {
46 obj.insert("type".to_string(), json!("object"));
47 }
48 Arc::new(obj)
49 })
50 .unwrap_or_else(|| {
51 Arc::new(serde_json::Map::from_iter([
53 ("type".to_string(), json!("object")),
54 ]))
55 });
56
57 Tool::new(name, description, input_schema)
58 })
59 })
60 .collect();
61
62 tools.push(create_plexus_respond_tool());
64
65 tools
66}
67
68fn create_plexus_respond_tool() -> Tool {
73 let schema = Arc::new(serde_json::Map::from_iter([
74 ("type".to_string(), json!("object")),
75 (
76 "properties".to_string(),
77 json!({
78 "request_id": {
79 "type": "string",
80 "description": "The request_id from the bidirectional request notification"
81 },
82 "response_data": {
83 "description": "The response data to send back to the server"
84 }
85 }),
86 ),
87 (
88 "required".to_string(),
89 json!(["request_id", "response_data"]),
90 ),
91 ]));
92
93 Tool::new(
94 "_plexus_respond".to_string(),
95 "Respond to a bidirectional request from the server. \
96 When you receive a logging notification with type 'request', \
97 use this tool to send your response back."
98 .to_string(),
99 schema,
100 )
101}
102
103fn plexus_to_mcp_error(e: PlexusError) -> McpError {
109 match e {
110 PlexusError::ActivationNotFound(name) => {
111 McpError::invalid_params(format!("Unknown activation: {}", name), None)
112 }
113 PlexusError::MethodNotFound { activation, method } => {
114 McpError::invalid_params(format!("Unknown method: {}.{}", activation, method), None)
115 }
116 PlexusError::InvalidParams(reason) => McpError::invalid_params(reason, None),
117 PlexusError::ExecutionError(error) => McpError::internal_error(error, None),
118 PlexusError::HandleNotSupported(activation) => {
119 McpError::invalid_params(format!("Handle resolution not supported: {}", activation), None)
120 }
121 PlexusError::TransportError(kind) => {
122 McpError::internal_error(format!("Transport error: {}", kind), None)
123 }
124 PlexusError::Unauthenticated(reason) => {
125 McpError::invalid_request(format!("Authentication required: {}", reason), None)
126 }
127 }
128}
129
130#[derive(Clone)]
136pub struct PlexusMcpBridge {
137 hub: Arc<DynamicHub>,
138}
139
140impl PlexusMcpBridge {
141 pub fn new(hub: Arc<DynamicHub>) -> Self {
142 Self { hub }
143 }
144
145 async fn handle_plexus_respond(
149 &self,
150 request: CallToolRequestParam,
151 ) -> Result<CallToolResult, McpError> {
152 let arguments = request
153 .arguments
154 .map(serde_json::Value::Object)
155 .unwrap_or(json!({}));
156
157 let request_id = arguments
159 .get("request_id")
160 .and_then(|v| v.as_str())
161 .ok_or_else(|| McpError::invalid_params("Missing required parameter: request_id", None))?
162 .to_string();
163
164 let response_data = arguments
165 .get("response_data")
166 .cloned()
167 .ok_or_else(|| {
168 McpError::invalid_params("Missing required parameter: response_data", None)
169 })?;
170
171 tracing::debug!(
172 request_id = %request_id,
173 "Handling _plexus_respond"
174 );
175
176 match handle_pending_response(&request_id, response_data) {
178 Ok(()) => Ok(CallToolResult::success(vec![Content::text(
179 "Response delivered successfully",
180 )])),
181 Err(BidirError::UnknownRequest) => {
182 tracing::warn!(request_id = %request_id, "Unknown request ID in _plexus_respond");
183 Err(McpError::invalid_params(
184 format!("Unknown request ID: {}. The request may have timed out or been cancelled.", request_id),
185 None,
186 ))
187 }
188 Err(BidirError::ChannelClosed) => {
189 tracing::warn!(request_id = %request_id, "Channel closed in _plexus_respond");
190 Err(McpError::internal_error(
191 "Response channel was closed (request may have timed out)",
192 None,
193 ))
194 }
195 Err(e) => {
196 tracing::error!(request_id = %request_id, error = ?e, "Error in _plexus_respond");
197 Err(McpError::internal_error(format!("Failed to deliver response: {}", e), None))
198 }
199 }
200 }
201}
202
203impl ServerHandler for PlexusMcpBridge {
204 fn get_info(&self) -> ServerInfo {
205 ServerInfo {
206 protocol_version: ProtocolVersion::LATEST,
207 capabilities: ServerCapabilities::builder()
208 .enable_tools()
209 .enable_logging()
210 .build(),
211 server_info: Implementation::from_build_env(),
212 instructions: Some(
213 "Plexus MCP server - provides access to all registered activations.".into(),
214 ),
215 }
216 }
217
218 async fn list_tools(
219 &self,
220 _request: Option<PaginatedRequestParam>,
221 _ctx: RequestContext<RoleServer>,
222 ) -> Result<ListToolsResult, McpError> {
223 let schemas = self.hub.list_plugin_schemas();
224 let tools = schemas_to_rmcp_tools(schemas);
225
226 tracing::debug!("Listing {} tools", tools.len());
227
228 Ok(ListToolsResult {
229 tools,
230 next_cursor: None,
231 meta: None,
232 })
233 }
234
235 async fn call_tool(
236 &self,
237 request: CallToolRequestParam,
238 ctx: RequestContext<RoleServer>,
239 ) -> Result<CallToolResult, McpError> {
240 let method_name = &request.name;
241
242 if method_name == "_plexus_respond" {
244 return self.handle_plexus_respond(request).await;
245 }
246
247 let arguments = request
248 .arguments
249 .map(serde_json::Value::Object)
250 .unwrap_or(json!({}));
251
252 tracing::debug!("Calling tool: {} with args: {:?}", method_name, arguments);
253
254 let progress_token = ctx.meta.get_progress_token();
256
257 let logger = format!("plexus.{}", method_name);
259
260 let stream = self
262 .hub
263 .route(method_name, arguments, None)
264 .await
265 .map_err(plexus_to_mcp_error)?;
266
267 let mut had_error = false;
269 let mut buffered_data: Vec<serde_json::Value> = Vec::new();
270 let mut error_messages: Vec<String> = Vec::new();
271
272 tokio::pin!(stream);
273 while let Some(item) = stream.next().await {
274 if ctx.ct.is_cancelled() {
276 return Err(McpError::internal_error("Cancelled", None));
277 }
278
279 match &item {
280 PlexusStreamItem::Progress {
281 message,
282 percentage,
283 ..
284 } => {
285 if let Some(ref token) = progress_token {
287 let _ = ctx
288 .peer
289 .notify_progress(ProgressNotificationParam {
290 progress_token: token.clone(),
291 progress: percentage.unwrap_or(0.0) as f64,
292 total: None,
293 message: Some(message.clone()),
294 })
295 .await;
296 }
297 }
298
299 PlexusStreamItem::Data {
300 content, content_type, ..
301 } => {
302 buffered_data.push(content.clone());
304
305 let _ = ctx
307 .peer
308 .notify_logging_message(LoggingMessageNotificationParam {
309 level: LoggingLevel::Info,
310 logger: Some(logger.clone()),
311 data: json!({
312 "type": "data",
313 "content_type": content_type,
314 "data": content,
315 }),
316 })
317 .await;
318 }
319
320 PlexusStreamItem::Error {
321 message, recoverable, ..
322 } => {
323 error_messages.push(message.clone());
325
326 let _ = ctx
327 .peer
328 .notify_logging_message(LoggingMessageNotificationParam {
329 level: LoggingLevel::Error,
330 logger: Some(logger.clone()),
331 data: json!({
332 "type": "error",
333 "error": message,
334 "recoverable": recoverable,
335 }),
336 })
337 .await;
338
339 if !recoverable {
340 had_error = true;
341 }
342 }
343
344 PlexusStreamItem::Request {
345 request_id,
346 request_data,
347 timeout_ms,
348 } => {
349 tracing::debug!(
352 request_id = %request_id,
353 timeout_ms = timeout_ms,
354 "Sending bidirectional request notification"
355 );
356
357 let _ = ctx
358 .peer
359 .notify_logging_message(LoggingMessageNotificationParam {
360 level: LoggingLevel::Info,
361 logger: Some("plexus.bidir".into()),
362 data: json!({
363 "type": "request",
364 "request_id": request_id,
365 "request_data": request_data,
366 "timeout_ms": timeout_ms,
367 }),
368 })
369 .await;
370 }
371
372 PlexusStreamItem::Done { .. } => {
373 break;
374 }
375 }
376 }
377
378 if had_error {
380 let error_content = if error_messages.is_empty() {
381 "Stream completed with errors".to_string()
382 } else {
383 error_messages.join("\n")
384 };
385 Ok(CallToolResult::error(vec![Content::text(error_content)]))
386 } else {
387 let text_content = if buffered_data.is_empty() {
389 "(no output)".to_string()
390 } else if buffered_data.len() == 1 {
391 match &buffered_data[0] {
393 serde_json::Value::String(s) => s.clone(),
394 other => serde_json::to_string_pretty(other).unwrap_or_default(),
395 }
396 } else {
397 let all_strings = buffered_data.iter().all(|v| v.is_string());
399 if all_strings {
400 buffered_data
401 .iter()
402 .filter_map(|v| v.as_str())
403 .collect::<Vec<_>>()
404 .join("")
405 } else {
406 serde_json::to_string_pretty(&buffered_data).unwrap_or_default()
407 }
408 };
409
410 let approx_tokens = (text_content.len() + 3) / 4;
412 let content_with_tokens = format!(
413 "{}\n\n[~{} tokens]",
414 text_content,
415 approx_tokens
416 );
417
418 Ok(CallToolResult::success(vec![Content::text(content_with_tokens)]))
419 }
420 }
421}