Skip to main content

plexus_substrate/
mcp_bridge.rs

//! MCP server bridge using rmcp with a Plexus RPC backend.
//!
//! This module implements the MCP protocol using the rmcp crate,
//! bridging MCP tool calls to Plexus RPC activation methods.

6use std::sync::Arc;
7
8use futures::StreamExt;
9use rmcp::{
10    ErrorData as McpError,
11    ServerHandler,
12    model::*,
13    service::{RequestContext, RoleServer},
14};
15use serde_json::json;
16
17use crate::plexus::{DynamicHub, PlexusError, PluginSchema};
18use crate::plexus::types::PlexusStreamItem;
19
20// =============================================================================
21// Schema Transformation
22// =============================================================================
23
24/// Convert Plexus RPC activation schemas to rmcp Tool format
25///
26/// MCP requires all tool inputSchema to have "type": "object" at root.
27/// schemars may produce schemas without this (e.g., for unit types).
28fn schemas_to_rmcp_tools(schemas: Vec<PluginSchema>) -> Vec<Tool> {
29    schemas
30        .into_iter()
31        .flat_map(|activation| {
32            let namespace = activation.namespace.clone();
33            activation.methods.into_iter().map(move |method| {
34                let name = format!("{}.{}", namespace, method.name);
35                let description = method.description.clone();
36
37                // Convert schemars::Schema to JSON, ensure "type": "object" exists
38                let input_schema = method
39                    .params
40                    .and_then(|s| serde_json::to_value(s).ok())
41                    .and_then(|v| v.as_object().cloned())
42                    .map(|mut obj| {
43                        // MCP requires "type": "object" at schema root
44                        if !obj.contains_key("type") {
45                            obj.insert("type".to_string(), json!("object"));
46                        }
47                        Arc::new(obj)
48                    })
49                    .unwrap_or_else(|| {
50                        // Empty params = empty object schema
51                        Arc::new(serde_json::Map::from_iter([
52                            ("type".to_string(), json!("object")),
53                        ]))
54                    });
55
56                Tool::new(name, description, input_schema)
57            })
58        })
59        .collect()
60}
61
62// =============================================================================
63// Error Mapping
64// =============================================================================
65
66/// Convert PlexusError to McpError
67fn plexus_to_mcp_error(e: PlexusError) -> McpError {
68    match e {
69        PlexusError::ActivationNotFound(name) => {
70            McpError::invalid_params(format!("Unknown activation: {}", name), None)
71        }
72        PlexusError::MethodNotFound { activation, method } => {
73            McpError::invalid_params(format!("Unknown method: {}.{}", activation, method), None)
74        }
75        PlexusError::InvalidParams(reason) => McpError::invalid_params(reason, None),
76        PlexusError::ExecutionError(error) => McpError::internal_error(error, None),
77        PlexusError::HandleNotSupported(activation) => {
78            McpError::invalid_params(format!("Handle resolution not supported: {}", activation), None)
79        }
80        PlexusError::TransportError(kind) => {
81            McpError::internal_error(format!("Transport error: {:?}", kind), None)
82        }
83        PlexusError::Unauthenticated(reason) => {
84            McpError::invalid_params(format!("Unauthenticated: {}", reason), None)
85        }
86    }
87}
88
89// =============================================================================
90// Plexus RPC MCP Bridge
91// =============================================================================
92
93/// MCP handler that bridges to Plexus RPC server
94#[derive(Clone)]
95pub struct PlexusMcpBridge {
96    hub: Arc<DynamicHub>,
97}
98
99impl PlexusMcpBridge {
100    pub fn new(hub: Arc<DynamicHub>) -> Self {
101        Self { hub }
102    }
103}
104
105impl ServerHandler for PlexusMcpBridge {
106    fn get_info(&self) -> ServerInfo {
107        ServerInfo {
108            protocol_version: ProtocolVersion::LATEST,
109            capabilities: ServerCapabilities::builder()
110                .enable_tools()
111                .enable_logging()
112                .build(),
113            server_info: Implementation::from_build_env(),
114            instructions: Some(
115                "Plexus MCP server - provides access to all registered activations.".into(),
116            ),
117        }
118    }
119
120    async fn list_tools(
121        &self,
122        _request: Option<PaginatedRequestParam>,
123        _ctx: RequestContext<RoleServer>,
124    ) -> Result<ListToolsResult, McpError> {
125        let schemas = self.hub.list_plugin_schemas();
126        let tools = schemas_to_rmcp_tools(schemas);
127
128        tracing::debug!("Listing {} tools", tools.len());
129
130        Ok(ListToolsResult {
131            tools,
132            next_cursor: None,
133            meta: None,
134        })
135    }
136
137    async fn call_tool(
138        &self,
139        request: CallToolRequestParam,
140        ctx: RequestContext<RoleServer>,
141    ) -> Result<CallToolResult, McpError> {
142        let method_name = &request.name;
143        let arguments = request
144            .arguments
145            .map(serde_json::Value::Object)
146            .unwrap_or(json!({}));
147
148        tracing::debug!("Calling tool: {} with args: {:?}", method_name, arguments);
149
150        // Get progress token if provided
151        let progress_token = ctx.meta.get_progress_token();
152
153        // Logger name: plexus.namespace.method (e.g., plexus.bash.execute)
154        let logger = format!("plexus.{}", method_name);
155
156        // Call Plexus RPC hub and get stream
157        let stream = self
158            .hub
159            .route(method_name, arguments, None)
160            .await
161            .map_err(plexus_to_mcp_error)?;
162
163        // Stream events via notifications AND buffer for final result
164        let mut had_error = false;
165        let mut buffered_data: Vec<serde_json::Value> = Vec::new();
166        let mut error_messages: Vec<String> = Vec::new();
167
168        tokio::pin!(stream);
169        while let Some(item) = stream.next().await {
170            // Check cancellation on each iteration
171            if ctx.ct.is_cancelled() {
172                return Err(McpError::internal_error("Cancelled", None));
173            }
174
175            match &item {
176                PlexusStreamItem::Progress {
177                    message,
178                    percentage,
179                    ..
180                } => {
181                    // Only send progress if client provided token
182                    if let Some(ref token) = progress_token {
183                        let _ = ctx
184                            .peer
185                            .notify_progress(ProgressNotificationParam {
186                                progress_token: token.clone(),
187                                progress: percentage.unwrap_or(0.0) as f64,
188                                total: None,
189                                message: Some(message.clone()),
190                            })
191                            .await;
192                    }
193                }
194
195                PlexusStreamItem::Data {
196                    content, content_type, ..
197                } => {
198                    // Buffer data for final result
199                    buffered_data.push(content.clone());
200
201                    // Also stream via notifications for real-time consumers
202                    let _ = ctx
203                        .peer
204                        .notify_logging_message(LoggingMessageNotificationParam {
205                            level: LoggingLevel::Info,
206                            logger: Some(logger.clone()),
207                            data: json!({
208                                "type": "data",
209                                "content_type": content_type,
210                                "data": content,
211                            }),
212                        })
213                        .await;
214                }
215
216                PlexusStreamItem::Error {
217                    message, recoverable, ..
218                } => {
219                    // Buffer errors for final result
220                    error_messages.push(message.clone());
221
222                    let _ = ctx
223                        .peer
224                        .notify_logging_message(LoggingMessageNotificationParam {
225                            level: LoggingLevel::Error,
226                            logger: Some(logger.clone()),
227                            data: json!({
228                                "type": "error",
229                                "error": message,
230                                "recoverable": recoverable,
231                            }),
232                        })
233                        .await;
234
235                    if !recoverable {
236                        had_error = true;
237                    }
238                }
239
240                PlexusStreamItem::Done { .. } => {
241                    break;
242                }
243
244                PlexusStreamItem::Request {
245                    request_id,
246                    request_data,
247                    timeout_ms,
248                } => {
249                    // Send bidirectional request to client via logging notification
250                    // Client should respond via _plexus_respond tool
251                    let _ = ctx
252                        .peer
253                        .notify_logging_message(LoggingMessageNotificationParam {
254                            level: LoggingLevel::Info,
255                            logger: Some(logger.clone()),
256                            data: json!({
257                                "type": "request",
258                                "request_id": request_id,
259                                "request_data": request_data,
260                                "timeout_ms": timeout_ms,
261                            }),
262                        })
263                        .await;
264                }
265            }
266        }
267
268        // Return buffered data in the final result
269        if had_error {
270            let error_content = if error_messages.is_empty() {
271                "Stream completed with errors".to_string()
272            } else {
273                error_messages.join("\n")
274            };
275            Ok(CallToolResult::error(vec![Content::text(error_content)]))
276        } else {
277            // Convert buffered data to content
278            let text_content = if buffered_data.is_empty() {
279                "(no output)".to_string()
280            } else if buffered_data.len() == 1 {
281                // Single value - return as text if string, otherwise JSON
282                match &buffered_data[0] {
283                    serde_json::Value::String(s) => s.clone(),
284                    other => serde_json::to_string_pretty(other).unwrap_or_default(),
285                }
286            } else {
287                // Multiple values - join strings or return as JSON array
288                let all_strings = buffered_data.iter().all(|v| v.is_string());
289                if all_strings {
290                    buffered_data
291                        .iter()
292                        .filter_map(|v| v.as_str())
293                        .collect::<Vec<_>>()
294                        .join("")
295                } else {
296                    serde_json::to_string_pretty(&buffered_data).unwrap_or_default()
297                }
298            };
299
300            Ok(CallToolResult::success(vec![Content::text(text_content)]))
301        }
302    }
303}