Skip to main content

plexus_core/
mcp_bridge.rs

1//! MCP server bridge using rmcp with Plexus backend
2//!
3//! This module implements the MCP protocol using the rmcp crate,
4//! bridging MCP tool calls to Plexus activation methods.
5
6use std::sync::Arc;
7
8use futures::StreamExt;
9use rmcp::{
10    ErrorData as McpError,
11    ServerHandler,
12    model::*,
13    service::{RequestContext, RoleServer},
14};
15use serde_json::json;
16
17use crate::plexus::bidirectional::{handle_pending_response, BidirError};
18use crate::plexus::types::PlexusStreamItem;
19use crate::plexus::{DynamicHub, PlexusError, PluginSchema};
20
21// =============================================================================
22// Schema Transformation
23// =============================================================================
24
25/// Convert Plexus activation schemas to rmcp Tool format
26///
27/// MCP requires all tool inputSchema to have "type": "object" at root.
28/// schemars may produce schemas without this (e.g., for unit types).
29fn schemas_to_rmcp_tools(schemas: Vec<PluginSchema>) -> Vec<Tool> {
30    let mut tools: Vec<Tool> = schemas
31        .into_iter()
32        .flat_map(|activation| {
33            let namespace = activation.namespace.clone();
34            activation.methods.into_iter().map(move |method| {
35                let name = format!("{}.{}", namespace, method.name);
36                let description = method.description.clone();
37
38                // Convert schemars::Schema to JSON, ensure "type": "object" exists
39                let input_schema = method
40                    .params
41                    .and_then(|s| serde_json::to_value(s).ok())
42                    .and_then(|v| v.as_object().cloned())
43                    .map(|mut obj| {
44                        // MCP requires "type": "object" at schema root
45                        if !obj.contains_key("type") {
46                            obj.insert("type".to_string(), json!("object"));
47                        }
48                        Arc::new(obj)
49                    })
50                    .unwrap_or_else(|| {
51                        // Empty params = empty object schema
52                        Arc::new(serde_json::Map::from_iter([
53                            ("type".to_string(), json!("object")),
54                        ]))
55                    });
56
57                Tool::new(name, description, input_schema)
58            })
59        })
60        .collect();
61
62    // Add the _plexus_respond tool for bidirectional communication
63    tools.push(create_plexus_respond_tool());
64
65    tools
66}
67
68/// Create the _plexus_respond tool for bidirectional communication
69///
70/// This tool allows MCP clients to respond to bidirectional requests
71/// sent via logging notifications (type: "request").
72fn create_plexus_respond_tool() -> Tool {
73    let schema = Arc::new(serde_json::Map::from_iter([
74        ("type".to_string(), json!("object")),
75        (
76            "properties".to_string(),
77            json!({
78                "request_id": {
79                    "type": "string",
80                    "description": "The request_id from the bidirectional request notification"
81                },
82                "response_data": {
83                    "description": "The response data to send back to the server"
84                }
85            }),
86        ),
87        (
88            "required".to_string(),
89            json!(["request_id", "response_data"]),
90        ),
91    ]));
92
93    Tool::new(
94        "_plexus_respond".to_string(),
95        "Respond to a bidirectional request from the server. \
96         When you receive a logging notification with type 'request', \
97         use this tool to send your response back."
98            .to_string(),
99        schema,
100    )
101}
102
103// =============================================================================
104// Error Mapping
105// =============================================================================
106
107/// Convert PlexusError to McpError
108fn plexus_to_mcp_error(e: PlexusError) -> McpError {
109    match e {
110        PlexusError::ActivationNotFound(name) => {
111            McpError::invalid_params(format!("Unknown activation: {}", name), None)
112        }
113        PlexusError::MethodNotFound { activation, method } => {
114            McpError::invalid_params(format!("Unknown method: {}.{}", activation, method), None)
115        }
116        PlexusError::InvalidParams(reason) => McpError::invalid_params(reason, None),
117        PlexusError::ExecutionError(error) => McpError::internal_error(error, None),
118        PlexusError::HandleNotSupported(activation) => {
119            McpError::invalid_params(format!("Handle resolution not supported: {}", activation), None)
120        }
121        PlexusError::TransportError(kind) => {
122            McpError::internal_error(format!("Transport error: {}", kind), None)
123        }
124        PlexusError::Unauthenticated(reason) => {
125            McpError::invalid_request(format!("Authentication required: {}", reason), None)
126        }
127    }
128}
129
130// =============================================================================
131// Plexus MCP Bridge
132// =============================================================================
133
134/// MCP handler that bridges to Plexus RPC hub
135#[derive(Clone)]
136pub struct PlexusMcpBridge {
137    hub: Arc<DynamicHub>,
138}
139
140impl PlexusMcpBridge {
141    pub fn new(hub: Arc<DynamicHub>) -> Self {
142        Self { hub }
143    }
144
145    /// Handle the _plexus_respond tool call
146    ///
147    /// Routes the response back to the waiting BidirChannel via the global registry.
148    async fn handle_plexus_respond(
149        &self,
150        request: CallToolRequestParam,
151    ) -> Result<CallToolResult, McpError> {
152        let arguments = request
153            .arguments
154            .map(serde_json::Value::Object)
155            .unwrap_or(json!({}));
156
157        // Extract request_id and response_data
158        let request_id = arguments
159            .get("request_id")
160            .and_then(|v| v.as_str())
161            .ok_or_else(|| McpError::invalid_params("Missing required parameter: request_id", None))?
162            .to_string();
163
164        let response_data = arguments
165            .get("response_data")
166            .cloned()
167            .ok_or_else(|| {
168                McpError::invalid_params("Missing required parameter: response_data", None)
169            })?;
170
171        tracing::debug!(
172            request_id = %request_id,
173            "Handling _plexus_respond"
174        );
175
176        // Forward response through global registry
177        match handle_pending_response(&request_id, response_data) {
178            Ok(()) => Ok(CallToolResult::success(vec![Content::text(
179                "Response delivered successfully",
180            )])),
181            Err(BidirError::UnknownRequest) => {
182                tracing::warn!(request_id = %request_id, "Unknown request ID in _plexus_respond");
183                Err(McpError::invalid_params(
184                    format!("Unknown request ID: {}. The request may have timed out or been cancelled.", request_id),
185                    None,
186                ))
187            }
188            Err(BidirError::ChannelClosed) => {
189                tracing::warn!(request_id = %request_id, "Channel closed in _plexus_respond");
190                Err(McpError::internal_error(
191                    "Response channel was closed (request may have timed out)",
192                    None,
193                ))
194            }
195            Err(e) => {
196                tracing::error!(request_id = %request_id, error = ?e, "Error in _plexus_respond");
197                Err(McpError::internal_error(format!("Failed to deliver response: {}", e), None))
198            }
199        }
200    }
201}
202
203impl ServerHandler for PlexusMcpBridge {
204    fn get_info(&self) -> ServerInfo {
205        ServerInfo {
206            protocol_version: ProtocolVersion::LATEST,
207            capabilities: ServerCapabilities::builder()
208                .enable_tools()
209                .enable_logging()
210                .build(),
211            server_info: Implementation::from_build_env(),
212            instructions: Some(
213                "Plexus MCP server - provides access to all registered activations.".into(),
214            ),
215        }
216    }
217
218    async fn list_tools(
219        &self,
220        _request: Option<PaginatedRequestParam>,
221        _ctx: RequestContext<RoleServer>,
222    ) -> Result<ListToolsResult, McpError> {
223        let schemas = self.hub.list_plugin_schemas();
224        let tools = schemas_to_rmcp_tools(schemas);
225
226        tracing::debug!("Listing {} tools", tools.len());
227
228        Ok(ListToolsResult {
229            tools,
230            next_cursor: None,
231            meta: None,
232        })
233    }
234
235    async fn call_tool(
236        &self,
237        request: CallToolRequestParam,
238        ctx: RequestContext<RoleServer>,
239    ) -> Result<CallToolResult, McpError> {
240        let method_name = &request.name;
241
242        // Handle _plexus_respond tool specially
243        if method_name == "_plexus_respond" {
244            return self.handle_plexus_respond(request).await;
245        }
246
247        let arguments = request
248            .arguments
249            .map(serde_json::Value::Object)
250            .unwrap_or(json!({}));
251
252        tracing::debug!("Calling tool: {} with args: {:?}", method_name, arguments);
253
254        // Get progress token if provided
255        let progress_token = ctx.meta.get_progress_token();
256
257        // Logger name: plexus.namespace.method (e.g., plexus.bash.execute)
258        let logger = format!("plexus.{}", method_name);
259
260        // Call Plexus RPC hub and get stream
261        let stream = self
262            .hub
263            .route(method_name, arguments, None)
264            .await
265            .map_err(plexus_to_mcp_error)?;
266
267        // Stream events via notifications AND buffer for final result
268        let mut had_error = false;
269        let mut buffered_data: Vec<serde_json::Value> = Vec::new();
270        let mut error_messages: Vec<String> = Vec::new();
271
272        tokio::pin!(stream);
273        while let Some(item) = stream.next().await {
274            // Check cancellation on each iteration
275            if ctx.ct.is_cancelled() {
276                return Err(McpError::internal_error("Cancelled", None));
277            }
278
279            match &item {
280                PlexusStreamItem::Progress {
281                    message,
282                    percentage,
283                    ..
284                } => {
285                    // Only send progress if client provided token
286                    if let Some(ref token) = progress_token {
287                        let _ = ctx
288                            .peer
289                            .notify_progress(ProgressNotificationParam {
290                                progress_token: token.clone(),
291                                progress: percentage.unwrap_or(0.0) as f64,
292                                total: None,
293                                message: Some(message.clone()),
294                            })
295                            .await;
296                    }
297                }
298
299                PlexusStreamItem::Data {
300                    content, content_type, ..
301                } => {
302                    // Buffer data for final result
303                    buffered_data.push(content.clone());
304
305                    // Also stream via notifications for real-time consumers
306                    let _ = ctx
307                        .peer
308                        .notify_logging_message(LoggingMessageNotificationParam {
309                            level: LoggingLevel::Info,
310                            logger: Some(logger.clone()),
311                            data: json!({
312                                "type": "data",
313                                "content_type": content_type,
314                                "data": content,
315                            }),
316                        })
317                        .await;
318                }
319
320                PlexusStreamItem::Error {
321                    message, recoverable, ..
322                } => {
323                    // Buffer errors for final result
324                    error_messages.push(message.clone());
325
326                    let _ = ctx
327                        .peer
328                        .notify_logging_message(LoggingMessageNotificationParam {
329                            level: LoggingLevel::Error,
330                            logger: Some(logger.clone()),
331                            data: json!({
332                                "type": "error",
333                                "error": message,
334                                "recoverable": recoverable,
335                            }),
336                        })
337                        .await;
338
339                    if !recoverable {
340                        had_error = true;
341                    }
342                }
343
344                PlexusStreamItem::Request {
345                    request_id,
346                    request_data,
347                    timeout_ms,
348                } => {
349                    // Send bidirectional request as logging notification
350                    // Client responds via _plexus_respond tool call
351                    tracing::debug!(
352                        request_id = %request_id,
353                        timeout_ms = timeout_ms,
354                        "Sending bidirectional request notification"
355                    );
356
357                    let _ = ctx
358                        .peer
359                        .notify_logging_message(LoggingMessageNotificationParam {
360                            level: LoggingLevel::Info,
361                            logger: Some("plexus.bidir".into()),
362                            data: json!({
363                                "type": "request",
364                                "request_id": request_id,
365                                "request_data": request_data,
366                                "timeout_ms": timeout_ms,
367                            }),
368                        })
369                        .await;
370                }
371
372                PlexusStreamItem::Done { .. } => {
373                    break;
374                }
375            }
376        }
377
378        // Return buffered data in the final result
379        if had_error {
380            let error_content = if error_messages.is_empty() {
381                "Stream completed with errors".to_string()
382            } else {
383                error_messages.join("\n")
384            };
385            Ok(CallToolResult::error(vec![Content::text(error_content)]))
386        } else {
387            // Convert buffered data to content
388            let text_content = if buffered_data.is_empty() {
389                "(no output)".to_string()
390            } else if buffered_data.len() == 1 {
391                // Single value - return as text if string, otherwise JSON
392                match &buffered_data[0] {
393                    serde_json::Value::String(s) => s.clone(),
394                    other => serde_json::to_string_pretty(other).unwrap_or_default(),
395                }
396            } else {
397                // Multiple values - join strings or return as JSON array
398                let all_strings = buffered_data.iter().all(|v| v.is_string());
399                if all_strings {
400                    buffered_data
401                        .iter()
402                        .filter_map(|v| v.as_str())
403                        .collect::<Vec<_>>()
404                        .join("")
405                } else {
406                    serde_json::to_string_pretty(&buffered_data).unwrap_or_default()
407                }
408            };
409
410            // Estimate tokens (~4 chars per token for JSON/text)
411            let approx_tokens = (text_content.len() + 3) / 4;
412            let content_with_tokens = format!(
413                "{}\n\n[~{} tokens]",
414                text_content,
415                approx_tokens
416            );
417
418            Ok(CallToolResult::success(vec![Content::text(content_with_tokens)]))
419        }
420    }
421}