1pub mod auth;
25pub mod client;
26pub mod error;
27pub mod handler;
28pub mod server;
29pub mod types;
30
31#[cfg(feature = "proxy")]
32pub mod proxy;
33
34pub use error::*;
35pub use types::*;
36
37#[cfg(feature = "client")]
38pub use client::*;
39
40#[cfg(feature = "server")]
41pub use server::*;
42
43pub use auth::*;
44
45#[cfg(feature = "proxy")]
46pub use proxy::*;
47
48use async_trait::async_trait;
49use protocol_transport_core::{
50 AsyncProtocolHandler, ProtocolError, ProtocolHandler, UniversalRequest, UniversalResponse,
51};
52use std::collections::HashMap;
53
/// MCP protocol revision this server reports as `protocol_version` in its
/// `initialize` response.
pub const MCP_PROTOCOL_VERSION: &str = "2025-06-18";

/// JSON-RPC protocol version string.
/// NOTE(review): not referenced in this file; presumably used by the JSON-RPC
/// types in `types` — confirm.
pub const JSONRPC_VERSION: &str = "2.0";
59
/// MCP request handler that dispatches JSON-RPC methods (e.g. `initialize`,
/// `tools/list`, `tools/call`) and adapts them to the transport-agnostic
/// `ProtocolHandler` / `AsyncProtocolHandler` traits.
///
/// Construct with `McpProtocolHandler::new()` and the `with_*` builder methods.
pub struct McpProtocolHandler {
    /// Capabilities advertised in the `initialize` result; `None` falls back to
    /// `ServerCapabilities::default()`.
    capabilities: Option<ServerCapabilities>,
    /// Optional authentication hook.
    /// NOTE(review): not consulted anywhere in this file — confirm whether
    /// another module enforces it before relying on it for access control.
    auth_handler: Option<Box<dyn AuthHandler>>,
    /// Source of tools for `tools/list` and `tools/call`; `None` means
    /// `tools/list` returns an empty list and `tools/call` returns an error result.
    tool_provider: Option<Box<dyn ToolProvider>>,
    /// Mode set via `with_query_mode`.
    /// NOTE(review): not read in this file — confirm where it takes effect.
    query_mode: QueryMode,
}
71
/// Supplies the tools exposed through the MCP `tools/list` and `tools/call`
/// methods. Implementations must be `Send + Sync`.
#[async_trait]
pub trait ToolProvider: Send + Sync {
    /// Returns the tools to advertise.
    ///
    /// An `Err` is propagated out of the `tools/list` handler unchanged.
    fn list_tools(&self) -> Result<Vec<Tool>, ProtocolError>;

    /// Invokes the tool named `name` with optional JSON `arguments`.
    ///
    /// An `Err` is propagated out of the `tools/call` handler unchanged; it is
    /// not converted into a `CallToolResult` with `is_error` set.
    async fn call_tool(
        &self,
        name: &str,
        arguments: Option<serde_json::Value>,
    ) -> Result<CallToolResult, ProtocolError>;
}
85
/// Query mode selectable via `McpProtocolHandler::with_query_mode`.
///
/// NOTE(review): the runtime meaning of each mode is not established in this
/// file — confirm against the code that reads it.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum QueryMode {
    /// The default mode.
    #[default]
    Single,
    /// Alternative, non-default mode.
    Aggregate,
}
100
101impl McpProtocolHandler {
102 pub fn new() -> Self {
104 Self {
105 capabilities: None,
106 auth_handler: None,
107 tool_provider: None,
108 query_mode: QueryMode::Single,
109 }
110 }
111
112 pub fn with_capabilities(mut self, capabilities: ServerCapabilities) -> Self {
114 self.capabilities = Some(capabilities);
115 self
116 }
117
118 pub fn with_auth_handler<H: AuthHandler + 'static>(mut self, handler: H) -> Self {
120 self.auth_handler = Some(Box::new(handler));
121 self
122 }
123
124 pub fn with_tool_provider<P: ToolProvider + 'static>(mut self, provider: P) -> Self {
126 self.tool_provider = Some(Box::new(provider));
127 self
128 }
129
130 pub fn with_query_mode(mut self, query_mode: QueryMode) -> Self {
132 self.query_mode = query_mode;
133 self
134 }
135
136 async fn handle_mcp_method(
138 &self,
139 method: &str,
140 params: serde_json::Value,
141 id: Option<serde_json::Value>,
142 ) -> Result<JsonRpcResponse, ProtocolError> {
143 match method {
144 "initialize" => self.handle_initialize(params, id),
145 "tools/list" => self.handle_list_tools(params, id),
146 "tools/call" => self.handle_call_tool(params, id).await,
147 _ => Ok(JsonRpcResponse::error(
148 id,
149 JsonRpcError::method_not_found(&format!("Method '{}' not found", method)),
150 )),
151 }
152 }
153
154 fn handle_initialize(
155 &self,
156 params: serde_json::Value,
157 id: Option<serde_json::Value>,
158 ) -> Result<JsonRpcResponse, ProtocolError> {
159 let _init_request: InitializeRequest = serde_json::from_value(params)
160 .map_err(|e| ProtocolError::Parsing(format!("Invalid initialize request: {}", e)))?;
161
162 let result = InitializeResult {
163 protocol_version: MCP_PROTOCOL_VERSION.to_string(),
164 capabilities: self.capabilities.clone().unwrap_or_default(),
165 server_info: ServerInfo {
166 name: "promptfleet-mcp-server".to_string(),
167 version: env!("CARGO_PKG_VERSION").to_string(),
168 description: Some("PromptFleet MCP Server".to_string()),
169 },
170 };
171
172 Ok(JsonRpcResponse::success(id, serde_json::to_value(result)?))
173 }
174
175 fn handle_list_tools(
176 &self,
177 _params: serde_json::Value,
178 id: Option<serde_json::Value>,
179 ) -> Result<JsonRpcResponse, ProtocolError> {
180 let tools = match &self.tool_provider {
181 Some(provider) => provider.list_tools()?,
182 None => vec![], };
184
185 let result = ListToolsResult { tools };
186 Ok(JsonRpcResponse::success(id, serde_json::to_value(result)?))
187 }
188
189 async fn handle_call_tool(
190 &self,
191 params: serde_json::Value,
192 id: Option<serde_json::Value>,
193 ) -> Result<JsonRpcResponse, ProtocolError> {
194 let call_request: CallToolRequest = serde_json::from_value(params)
195 .map_err(|e| ProtocolError::Parsing(format!("Invalid call_tool request: {}", e)))?;
196
197 let result = match &self.tool_provider {
198 Some(provider) => {
199 provider
200 .call_tool(&call_request.name, call_request.arguments)
201 .await?
202 }
203 None => CallToolResult {
204 content: vec![Content::text("No tool provider configured")],
205 is_error: Some(true),
206 },
207 };
208
209 Ok(JsonRpcResponse::success(id, serde_json::to_value(result)?))
210 }
211}
212
213impl ProtocolHandler for McpProtocolHandler {
214 type Request = JsonRpcRequest;
215 type Response = JsonRpcResponse;
216 type Error = ProtocolError;
217
218 fn protocol_name(&self) -> &'static str {
219 "MCP"
220 }
221
222 fn encode_request(&self, request: &Self::Request) -> Result<UniversalRequest, Self::Error> {
223 let body = serde_json::to_vec(request)?;
224 let mut headers = HashMap::new();
225 headers.insert("content-type".to_string(), "application/json".to_string());
226 headers.insert(
227 "accept".to_string(),
228 "application/json, text/event-stream".to_string(),
229 );
230 headers.insert("x-protocol".to_string(), "MCP".to_string());
231
232 if let Some(id) = &request.id {
233 headers.insert("x-correlation-id".to_string(), id.to_string());
234 }
235
236 Ok(UniversalRequest {
237 method: "POST".to_string(),
238 uri: "/mcp/rpc".to_string(),
239 headers,
240 body,
241 protocol: "MCP".to_string(),
242 correlation_id: request
243 .id
244 .as_ref()
245 .map(|id| id.to_string())
246 .unwrap_or_else(|| uuid::Uuid::new_v4().to_string()),
247 })
248 }
249
250 fn decode_request(&self, universal: &UniversalRequest) -> Result<Self::Request, Self::Error> {
251 let request: JsonRpcRequest = serde_json::from_slice(&universal.body)?;
252 Ok(request)
253 }
254
255 fn encode_response(&self, response: &Self::Response) -> Result<UniversalResponse, Self::Error> {
256 let body = serde_json::to_vec(response)?;
257 let mut headers = HashMap::new();
258 headers.insert("content-type".to_string(), "application/json".to_string());
259 headers.insert("x-protocol".to_string(), "MCP".to_string());
260
261 Ok(UniversalResponse {
262 status: 200,
263 headers,
264 body,
265 protocol: "MCP".to_string(),
266 correlation_id: response
267 .id
268 .as_ref()
269 .map(|id| id.to_string())
270 .unwrap_or_else(|| uuid::Uuid::new_v4().to_string()),
271 })
272 }
273
274 fn decode_response(
275 &self,
276 universal: &UniversalResponse,
277 ) -> Result<Self::Response, Self::Error> {
278 let response: JsonRpcResponse = serde_json::from_slice(&universal.body)?;
279 Ok(response)
280 }
281}
282
283impl AsyncProtocolHandler for McpProtocolHandler {
284 fn protocol_name(&self) -> &'static str {
285 "MCP"
286 }
287
288 fn handle_request_sync(
289 &self,
290 request: UniversalRequest,
291 ) -> Result<UniversalResponse, ProtocolError> {
292 let body_str = String::from_utf8(request.body)
294 .map_err(|e| ProtocolError::Parsing(format!("Invalid UTF-8 in request body: {}", e)))?;
295
296 let json_request: serde_json::Value = serde_json::from_str(&body_str)
297 .map_err(|e| ProtocolError::Parsing(format!("Invalid JSON in request body: {}", e)))?;
298
299 let method = json_request["method"]
301 .as_str()
302 .ok_or_else(|| ProtocolError::Parsing("Missing 'method' field".to_string()))?;
303 let params = json_request.get("params").cloned().unwrap_or_default();
304 let id = json_request.get("id").cloned();
305
306 #[cfg(not(target_arch = "wasm32"))]
307 {
308 let response = tokio::runtime::Handle::current()
309 .block_on(self.handle_mcp_method(method, params, id))
310 .map_err(|e| ProtocolError::internal_error(&format!("MCP error: {:?}", e)))?;
311
312 let response_body =
313 serde_json::to_string(&response).map_err(ProtocolError::Serialization)?;
314
315 Ok(UniversalResponse {
316 status: 200,
317 headers: [("content-type".to_string(), "application/json".to_string())]
318 .iter()
319 .cloned()
320 .collect(),
321 body: response_body.into_bytes(),
322 protocol: "MCP".to_string(),
323 correlation_id: request.correlation_id,
324 })
325 }
326
327 #[cfg(target_arch = "wasm32")]
328 Err(ProtocolError::internal_error(
329 "Sync MCP handler not supported in WASM; use async handler",
330 ))
331 }
332}
333
334impl Default for McpProtocolHandler {
335 fn default() -> Self {
336 Self::new()
337 }
338}
339
340pub fn create_mcp_handler() -> McpProtocolHandler {
342 McpProtocolHandler::new().with_capabilities(ServerCapabilities::default())
343}
344
345pub fn create_mcp_handler_with_capabilities(
347 capabilities: ServerCapabilities,
348) -> McpProtocolHandler {
349 McpProtocolHandler::new().with_capabilities(capabilities)
350}