Skip to main content

plexus_substrate/
mcp_bridge.rs

//! MCP server bridge using rmcp with a Plexus RPC backend.
//!
//! This module implements the MCP protocol using the rmcp crate,
//! bridging MCP tool calls to Plexus RPC activation methods.
5
6use std::sync::Arc;
7
8use futures::StreamExt;
9use rmcp::{
10    ErrorData as McpError,
11    ServerHandler,
12    model::*,
13    service::{RequestContext, RoleServer},
14};
15use serde_json::json;
16
17use crate::plexus::{DynamicHub, PlexusError, PluginSchema};
18use crate::plexus::types::PlexusStreamItem;
19
// =============================================================================
// Schema Transformation
// =============================================================================
23
24/// Convert Plexus RPC activation schemas to rmcp Tool format
25///
26/// MCP requires all tool inputSchema to have "type": "object" at root.
27/// schemars may produce schemas without this (e.g., for unit types).
28fn schemas_to_rmcp_tools(schemas: Vec<PluginSchema>) -> Vec<Tool> {
29    schemas
30        .into_iter()
31        .flat_map(|activation| {
32            let namespace = activation.namespace.clone();
33            activation.methods.into_iter().map(move |method| {
34                let name = format!("{}.{}", namespace, method.name);
35                let description = method.description.clone();
36
37                // Convert schemars::Schema to JSON, ensure "type": "object" exists
38                let input_schema = method
39                    .params
40                    .and_then(|s| serde_json::to_value(s).ok())
41                    .and_then(|v| v.as_object().cloned())
42                    .map(|mut obj| {
43                        // MCP requires "type": "object" at schema root
44                        if !obj.contains_key("type") {
45                            obj.insert("type".to_string(), json!("object"));
46                        }
47                        Arc::new(obj)
48                    })
49                    .unwrap_or_else(|| {
50                        // Empty params = empty object schema
51                        Arc::new(serde_json::Map::from_iter([
52                            ("type".to_string(), json!("object")),
53                        ]))
54                    });
55
56                Tool::new(name, description, input_schema)
57            })
58        })
59        .collect()
60}
61
// =============================================================================
// Error Mapping
// =============================================================================
65
66/// Convert PlexusError to McpError
67fn plexus_to_mcp_error(e: PlexusError) -> McpError {
68    match e {
69        PlexusError::ActivationNotFound(name) => {
70            McpError::invalid_params(format!("Unknown activation: {}", name), None)
71        }
72        PlexusError::MethodNotFound { activation, method } => {
73            McpError::invalid_params(format!("Unknown method: {}.{}", activation, method), None)
74        }
75        PlexusError::InvalidParams(reason) => McpError::invalid_params(reason, None),
76        PlexusError::ExecutionError(error) => McpError::internal_error(error, None),
77        PlexusError::HandleNotSupported(activation) => {
78            McpError::invalid_params(format!("Handle resolution not supported: {}", activation), None)
79        }
80        PlexusError::TransportError(kind) => {
81            McpError::internal_error(format!("Transport error: {:?}", kind), None)
82        }
83    }
84}
85
// =============================================================================
// Plexus RPC MCP Bridge
// =============================================================================
89
90/// MCP handler that bridges to Plexus RPC server
91#[derive(Clone)]
92pub struct PlexusMcpBridge {
93    hub: Arc<DynamicHub>,
94}
95
96impl PlexusMcpBridge {
97    pub fn new(hub: Arc<DynamicHub>) -> Self {
98        Self { hub }
99    }
100}
101
102impl ServerHandler for PlexusMcpBridge {
103    fn get_info(&self) -> ServerInfo {
104        ServerInfo {
105            protocol_version: ProtocolVersion::LATEST,
106            capabilities: ServerCapabilities::builder()
107                .enable_tools()
108                .enable_logging()
109                .build(),
110            server_info: Implementation::from_build_env(),
111            instructions: Some(
112                "Plexus MCP server - provides access to all registered activations.".into(),
113            ),
114        }
115    }
116
117    async fn list_tools(
118        &self,
119        _request: Option<PaginatedRequestParam>,
120        _ctx: RequestContext<RoleServer>,
121    ) -> Result<ListToolsResult, McpError> {
122        let schemas = self.hub.list_plugin_schemas();
123        let tools = schemas_to_rmcp_tools(schemas);
124
125        tracing::debug!("Listing {} tools", tools.len());
126
127        Ok(ListToolsResult {
128            tools,
129            next_cursor: None,
130            meta: None,
131        })
132    }
133
134    async fn call_tool(
135        &self,
136        request: CallToolRequestParam,
137        ctx: RequestContext<RoleServer>,
138    ) -> Result<CallToolResult, McpError> {
139        let method_name = &request.name;
140        let arguments = request
141            .arguments
142            .map(serde_json::Value::Object)
143            .unwrap_or(json!({}));
144
145        tracing::debug!("Calling tool: {} with args: {:?}", method_name, arguments);
146
147        // Get progress token if provided
148        let progress_token = ctx.meta.get_progress_token();
149
150        // Logger name: plexus.namespace.method (e.g., plexus.bash.execute)
151        let logger = format!("plexus.{}", method_name);
152
153        // Call Plexus RPC hub and get stream
154        let stream = self
155            .hub
156            .route(method_name, arguments)
157            .await
158            .map_err(plexus_to_mcp_error)?;
159
160        // Stream events via notifications AND buffer for final result
161        let mut had_error = false;
162        let mut buffered_data: Vec<serde_json::Value> = Vec::new();
163        let mut error_messages: Vec<String> = Vec::new();
164
165        tokio::pin!(stream);
166        while let Some(item) = stream.next().await {
167            // Check cancellation on each iteration
168            if ctx.ct.is_cancelled() {
169                return Err(McpError::internal_error("Cancelled", None));
170            }
171
172            match &item {
173                PlexusStreamItem::Progress {
174                    message,
175                    percentage,
176                    ..
177                } => {
178                    // Only send progress if client provided token
179                    if let Some(ref token) = progress_token {
180                        let _ = ctx
181                            .peer
182                            .notify_progress(ProgressNotificationParam {
183                                progress_token: token.clone(),
184                                progress: percentage.unwrap_or(0.0) as f64,
185                                total: None,
186                                message: Some(message.clone()),
187                            })
188                            .await;
189                    }
190                }
191
192                PlexusStreamItem::Data {
193                    content, content_type, ..
194                } => {
195                    // Buffer data for final result
196                    buffered_data.push(content.clone());
197
198                    // Also stream via notifications for real-time consumers
199                    let _ = ctx
200                        .peer
201                        .notify_logging_message(LoggingMessageNotificationParam {
202                            level: LoggingLevel::Info,
203                            logger: Some(logger.clone()),
204                            data: json!({
205                                "type": "data",
206                                "content_type": content_type,
207                                "data": content,
208                            }),
209                        })
210                        .await;
211                }
212
213                PlexusStreamItem::Error {
214                    message, recoverable, ..
215                } => {
216                    // Buffer errors for final result
217                    error_messages.push(message.clone());
218
219                    let _ = ctx
220                        .peer
221                        .notify_logging_message(LoggingMessageNotificationParam {
222                            level: LoggingLevel::Error,
223                            logger: Some(logger.clone()),
224                            data: json!({
225                                "type": "error",
226                                "error": message,
227                                "recoverable": recoverable,
228                            }),
229                        })
230                        .await;
231
232                    if !recoverable {
233                        had_error = true;
234                    }
235                }
236
237                PlexusStreamItem::Done { .. } => {
238                    break;
239                }
240
241                PlexusStreamItem::Request {
242                    request_id,
243                    request_data,
244                    timeout_ms,
245                } => {
246                    // Send bidirectional request to client via logging notification
247                    // Client should respond via _plexus_respond tool
248                    let _ = ctx
249                        .peer
250                        .notify_logging_message(LoggingMessageNotificationParam {
251                            level: LoggingLevel::Info,
252                            logger: Some(logger.clone()),
253                            data: json!({
254                                "type": "request",
255                                "request_id": request_id,
256                                "request_data": request_data,
257                                "timeout_ms": timeout_ms,
258                            }),
259                        })
260                        .await;
261                }
262            }
263        }
264
265        // Return buffered data in the final result
266        if had_error {
267            let error_content = if error_messages.is_empty() {
268                "Stream completed with errors".to_string()
269            } else {
270                error_messages.join("\n")
271            };
272            Ok(CallToolResult::error(vec![Content::text(error_content)]))
273        } else {
274            // Convert buffered data to content
275            let text_content = if buffered_data.is_empty() {
276                "(no output)".to_string()
277            } else if buffered_data.len() == 1 {
278                // Single value - return as text if string, otherwise JSON
279                match &buffered_data[0] {
280                    serde_json::Value::String(s) => s.clone(),
281                    other => serde_json::to_string_pretty(other).unwrap_or_default(),
282                }
283            } else {
284                // Multiple values - join strings or return as JSON array
285                let all_strings = buffered_data.iter().all(|v| v.is_string());
286                if all_strings {
287                    buffered_data
288                        .iter()
289                        .filter_map(|v| v.as_str())
290                        .collect::<Vec<_>>()
291                        .join("")
292                } else {
293                    serde_json::to_string_pretty(&buffered_data).unwrap_or_default()
294                }
295            };
296
297            Ok(CallToolResult::success(vec![Content::text(text_content)]))
298        }
299    }
300}