mcpkit_rocket/
router.rs

1//! Router builder for MCP endpoints in Rocket.
2
3use crate::state::{HasServerInfo, McpState};
4use mcpkit_server::{PromptHandler, ResourceHandler, ServerHandler, ToolHandler};
5use rocket::fairing::{Fairing, Info, Kind};
6use rocket::http::Header;
7use rocket::{Build, Request, Response, Rocket};
8
9/// Builder for MCP Rocket routers.
10///
11/// Creates a pre-configured Rocket with MCP endpoints.
12///
13/// # Example
14///
15/// ```ignore
16/// use mcpkit_rocket::McpRouter;
17///
18/// struct MyHandler;
19///
20/// // Basic usage - launch the server
21/// #[rocket::main]
22/// async fn main() -> Result<(), rocket::Error> {
23///     McpRouter::new(MyHandler)
24///         .launch()
25///         .await?;
26///     Ok(())
27/// }
28/// ```
29pub struct McpRouter<H> {
30    state: McpState<H>,
31    enable_cors: bool,
32}
33
34impl<H> McpRouter<H>
35where
36    H: ServerHandler
37        + ToolHandler
38        + ResourceHandler
39        + PromptHandler
40        + HasServerInfo
41        + Send
42        + Sync
43        + 'static,
44{
45    /// Create a new MCP router with the given handler.
46    pub fn new(handler: H) -> Self {
47        Self {
48            state: McpState::new(handler),
49            enable_cors: false,
50        }
51    }
52
53    /// Enable CORS with permissive defaults.
54    #[must_use]
55    pub const fn with_cors(mut self) -> Self {
56        self.enable_cors = true;
57        self
58    }
59
60    /// Build a Rocket instance with MCP routes.
61    ///
62    /// Note: Due to Rocket's type system constraints, this method creates
63    /// routes that are specific to the handler type. Use the `create_routes!`
64    /// macro in your application to generate the routes.
65    #[must_use]
66    pub fn into_rocket(self) -> Rocket<Build> {
67        let mut rocket = rocket::build().manage(self.state);
68
69        if self.enable_cors {
70            rocket = rocket.attach(Cors);
71        }
72
73        rocket
74    }
75
76    /// Get the MCP state for use with custom route handlers.
77    #[must_use]
78    pub fn into_state(self) -> McpState<H> {
79        self.state
80    }
81
82    /// Launch the MCP server.
83    ///
84    /// This is a convenience method that provides a stdio-like experience.
85    /// Note: You'll need to mount the routes separately using macros.
86    pub async fn launch(self) -> Result<(), rocket::Error> {
87        let _ = self.into_rocket().launch().await?;
88        Ok(())
89    }
90}
91
92/// CORS fairing for permissive cross-origin requests.
93pub struct Cors;
94
95#[rocket::async_trait]
96impl Fairing for Cors {
97    fn info(&self) -> Info {
98        Info {
99            name: "CORS",
100            kind: Kind::Response,
101        }
102    }
103
104    async fn on_response<'r>(&self, _request: &'r Request<'_>, response: &mut Response<'r>) {
105        response.set_header(Header::new("Access-Control-Allow-Origin", "*"));
106        response.set_header(Header::new(
107            "Access-Control-Allow-Methods",
108            "GET, POST, OPTIONS",
109        ));
110        response.set_header(Header::new(
111            "Access-Control-Allow-Headers",
112            "Content-Type, mcp-protocol-version, mcp-session-id, last-event-id",
113        ));
114        response.set_header(Header::new(
115            "Access-Control-Expose-Headers",
116            "mcp-session-id",
117        ));
118    }
119}
120
/// Create MCP route handlers for a specific handler type.
///
/// This macro generates two Rocket route handlers in the calling module:
///
/// - `mcp_post`: `POST /mcp`, passing the protocol-version header, session-id
///   header and request body to `handle_mcp_post`.
/// - `mcp_sse`: `GET /mcp/sse`, returning the event stream produced by `handle_sse`.
///
/// Rocket route handlers must name a concrete managed-state type, so they are
/// generated at compile time for your specific handler type. Both routes read
/// `McpState<$handler_type>` from Rocket's managed state (as registered by
/// `McpRouter::into_rocket` or `rocket::build().manage(state)`). The generated
/// code uses absolute `::rocket` paths, so the calling crate must depend on
/// `rocket` directly.
///
/// # Example
///
/// ```ignore
/// use mcpkit_rocket::{create_mcp_routes, McpRouter};
///
/// struct MyHandler;
/// // ... implement ServerHandler, ToolHandler, etc. for MyHandler
///
/// // Generate the routes
/// create_mcp_routes!(MyHandler);
///
/// #[rocket::main]
/// async fn main() -> Result<(), rocket::Error> {
///     McpRouter::new(MyHandler)
///         .into_rocket()
///         .mount("/", rocket::routes![mcp_post, mcp_sse])
///         .launch()
///         .await?;
///
///     Ok(())
/// }
/// ```
#[macro_export]
macro_rules! create_mcp_routes {
    ($handler_type:ty) => {
        // Absolute paths keep the expansion independent of any local item
        // named `rocket` at the call site.
        #[::rocket::post("/mcp", data = "<body>")]
        async fn mcp_post(
            state: &::rocket::State<$crate::McpState<$handler_type>>,
            version: $crate::handler::ProtocolVersionHeader,
            session: $crate::handler::SessionIdHeader,
            body: String,
        ) -> $crate::handler::McpResponse {
            $crate::handler::handle_mcp_post(state.inner(), version.0.as_deref(), session.0, &body)
                .await
        }

        #[::rocket::get("/mcp/sse")]
        fn mcp_sse(
            state: &::rocket::State<$crate::McpState<$handler_type>>,
            session: $crate::handler::SessionIdHeader,
        ) -> ::rocket::response::stream::EventStream![] {
            $crate::handler::handle_sse(state.inner(), session.0)
        }
    };
}
174
#[cfg(test)]
mod tests {
    use super::*;
    use mcpkit_core::capability::{ServerCapabilities, ServerInfo};
    use mcpkit_core::error::McpError;
    use mcpkit_core::types::{
        GetPromptResult, Prompt, Resource, ResourceContents, Tool, ToolOutput,
    };
    use mcpkit_server::ServerHandler;
    use mcpkit_server::context::Context;

    /// Stub handler advertising tools, resources and prompts with canned replies.
    struct TestHandler;

    impl ServerHandler for TestHandler {
        fn server_info(&self) -> ServerInfo {
            let name = String::from("test-server");
            let version = String::from("1.0.0");
            ServerInfo {
                name,
                version,
                protocol_version: None,
            }
        }

        fn capabilities(&self) -> ServerCapabilities {
            let base = ServerCapabilities::new();
            base.with_tools().with_resources().with_prompts()
        }
    }

    impl ToolHandler for TestHandler {
        async fn list_tools(&self, _ctx: &Context<'_>) -> Result<Vec<Tool>, McpError> {
            // No tools are registered.
            Ok(Vec::new())
        }

        async fn call_tool(
            &self,
            _name: &str,
            _args: serde_json::Value,
            _ctx: &Context<'_>,
        ) -> Result<ToolOutput, McpError> {
            let output = ToolOutput::text("test");
            Ok(output)
        }
    }

    impl ResourceHandler for TestHandler {
        async fn list_resources(&self, _ctx: &Context<'_>) -> Result<Vec<Resource>, McpError> {
            // No resources are registered.
            Ok(Vec::new())
        }

        async fn read_resource(
            &self,
            uri: &str,
            _ctx: &Context<'_>,
        ) -> Result<Vec<ResourceContents>, McpError> {
            // Echo the requested URI back with fixed text content.
            let contents = ResourceContents::text(uri, "test");
            Ok(vec![contents])
        }
    }

    impl PromptHandler for TestHandler {
        async fn list_prompts(&self, _ctx: &Context<'_>) -> Result<Vec<Prompt>, McpError> {
            // No prompts are registered.
            Ok(Vec::new())
        }

        async fn get_prompt(
            &self,
            _name: &str,
            _args: Option<serde_json::Map<String, serde_json::Value>>,
            _ctx: &Context<'_>,
        ) -> Result<GetPromptResult, McpError> {
            let result = GetPromptResult {
                description: Some("Test prompt".to_string()),
                messages: Vec::new(),
            };
            Ok(result)
        }
    }

    #[test]
    fn test_router_builder() {
        // Building a CORS-enabled Rocket must not panic.
        let rocket = McpRouter::new(TestHandler).with_cors().into_rocket();
        drop(rocket);
    }

    #[test]
    fn test_state_extraction() {
        let state = McpRouter::new(TestHandler).into_state();
        let info = &state.server_info;

        assert_eq!(info.name, "test-server");
        assert_eq!(info.version, "1.0.0");
    }
}