1use crate::state::{HasServerInfo, McpState};
4use mcpkit_server::{PromptHandler, ResourceHandler, ServerHandler, ToolHandler};
5use rocket::fairing::{Fairing, Info, Kind};
6use rocket::http::Header;
7use rocket::{Build, Request, Response, Rocket};
8
9pub struct McpRouter<H> {
30 state: McpState<H>,
31 enable_cors: bool,
32}
33
34impl<H> McpRouter<H>
35where
36 H: ServerHandler
37 + ToolHandler
38 + ResourceHandler
39 + PromptHandler
40 + HasServerInfo
41 + Send
42 + Sync
43 + 'static,
44{
45 pub fn new(handler: H) -> Self {
47 Self {
48 state: McpState::new(handler),
49 enable_cors: false,
50 }
51 }
52
53 #[must_use]
55 pub const fn with_cors(mut self) -> Self {
56 self.enable_cors = true;
57 self
58 }
59
60 #[must_use]
66 pub fn into_rocket(self) -> Rocket<Build> {
67 let mut rocket = rocket::build().manage(self.state);
68
69 if self.enable_cors {
70 rocket = rocket.attach(Cors);
71 }
72
73 rocket
74 }
75
76 #[must_use]
78 pub fn into_state(self) -> McpState<H> {
79 self.state
80 }
81
82 pub async fn launch(self) -> Result<(), rocket::Error> {
87 let _ = self.into_rocket().launch().await?;
88 Ok(())
89 }
90}
91
92pub struct Cors;
94
95#[rocket::async_trait]
96impl Fairing for Cors {
97 fn info(&self) -> Info {
98 Info {
99 name: "CORS",
100 kind: Kind::Response,
101 }
102 }
103
104 async fn on_response<'r>(&self, _request: &'r Request<'_>, response: &mut Response<'r>) {
105 response.set_header(Header::new("Access-Control-Allow-Origin", "*"));
106 response.set_header(Header::new(
107 "Access-Control-Allow-Methods",
108 "GET, POST, OPTIONS",
109 ));
110 response.set_header(Header::new(
111 "Access-Control-Allow-Headers",
112 "Content-Type, mcp-protocol-version, mcp-session-id, last-event-id",
113 ));
114 response.set_header(Header::new(
115 "Access-Control-Expose-Headers",
116 "mcp-session-id",
117 ));
118 }
119}
120
/// Generates Rocket route functions that serve an MCP handler of type `$handler_type`.
///
/// The expansion defines two free functions in the calling module:
///
/// - `mcp_post`: `POST /mcp`. It reads the request body as a `String` and passes it to
///   `handler::handle_mcp_post`, together with the values extracted by the
///   `ProtocolVersionHeader` and `SessionIdHeader` request guards.
/// - `mcp_sse`: `GET /mcp/sse`. It returns the `EventStream` produced by
///   `handler::handle_sse` for the value extracted by the `SessionIdHeader` guard.
///
/// Both functions read `McpState<$handler_type>` from Rocket's managed state, so the Rocket
/// instance must manage that state. `McpRouter::into_rocket` does this through `manage`.
///
/// The routes are not mounted automatically; mount them on the Rocket instance yourself.
/// Because the generated function names are fixed, invoke the macro at most once per module.
///
/// ```ignore
/// create_mcp_routes!(MyHandler);
///
/// let rocket = McpRouter::new(MyHandler)
///     .into_rocket()
///     .mount("/", rocket::routes![mcp_post, mcp_sse]);
/// ```
#[macro_export]
macro_rules! create_mcp_routes {
    ($handler_type:ty) => {
        // `data = "<body>"` binds the raw request body to the `body: String` parameter.
        #[rocket::post("/mcp", data = "<body>")]
        async fn mcp_post(
            state: &::rocket::State<$crate::McpState<$handler_type>>,
            version: $crate::handler::ProtocolVersionHeader,
            session: $crate::handler::SessionIdHeader,
            body: String,
        ) -> $crate::handler::McpResponse {
            $crate::handler::handle_mcp_post(state.inner(), version.0.as_deref(), session.0, &body)
                .await
        }

        // Server-sent events endpoint; `EventStream![]` is Rocket's opaque stream return type.
        #[rocket::get("/mcp/sse")]
        fn mcp_sse(
            state: &::rocket::State<$crate::McpState<$handler_type>>,
            session: $crate::handler::SessionIdHeader,
        ) -> ::rocket::response::stream::EventStream![] {
            $crate::handler::handle_sse(state.inner(), session.0)
        }
    };
}
174
#[cfg(test)]
mod tests {
    use super::*;
    use mcpkit_core::capability::{ServerCapabilities, ServerInfo};
    use mcpkit_core::error::McpError;
    use mcpkit_core::types::{
        GetPromptResult, Prompt, Resource, ResourceContents, Tool, ToolOutput,
    };
    use mcpkit_server::ServerHandler;
    use mcpkit_server::context::Context;

    /// Minimal handler that advertises tools, resources and prompts and returns canned data.
    struct TestHandler;

    impl ServerHandler for TestHandler {
        fn server_info(&self) -> ServerInfo {
            let name = String::from("test-server");
            let version = String::from("1.0.0");
            ServerInfo {
                name,
                version,
                protocol_version: None,
            }
        }

        fn capabilities(&self) -> ServerCapabilities {
            let base = ServerCapabilities::new();
            base.with_tools().with_resources().with_prompts()
        }
    }

    impl ToolHandler for TestHandler {
        async fn list_tools(&self, _ctx: &Context<'_>) -> Result<Vec<Tool>, McpError> {
            Ok(Vec::new())
        }

        async fn call_tool(
            &self,
            _name: &str,
            _args: serde_json::Value,
            _ctx: &Context<'_>,
        ) -> Result<ToolOutput, McpError> {
            let output = ToolOutput::text("test");
            Ok(output)
        }
    }

    impl ResourceHandler for TestHandler {
        async fn list_resources(&self, _ctx: &Context<'_>) -> Result<Vec<Resource>, McpError> {
            Ok(Vec::new())
        }

        async fn read_resource(
            &self,
            uri: &str,
            _ctx: &Context<'_>,
        ) -> Result<Vec<ResourceContents>, McpError> {
            // Echo the requested URI back with fixed text content.
            let contents = ResourceContents::text(uri, "test");
            Ok(vec![contents])
        }
    }

    impl PromptHandler for TestHandler {
        async fn list_prompts(&self, _ctx: &Context<'_>) -> Result<Vec<Prompt>, McpError> {
            Ok(Vec::new())
        }

        async fn get_prompt(
            &self,
            _name: &str,
            _args: Option<serde_json::Map<String, serde_json::Value>>,
            _ctx: &Context<'_>,
        ) -> Result<GetPromptResult, McpError> {
            let description = Some(String::from("Test prompt"));
            Ok(GetPromptResult {
                description,
                messages: Vec::new(),
            })
        }
    }

    /// Building a Rocket instance with CORS enabled must not panic.
    #[test]
    fn test_router_builder() {
        let rocket = McpRouter::new(TestHandler).with_cors().into_rocket();
        drop(rocket);
    }

    /// The extracted state carries the handler's server info.
    #[test]
    fn test_state_extraction() {
        let state = McpRouter::new(TestHandler).into_state();
        let info = &state.server_info;

        assert_eq!(info.name, "test-server");
        assert_eq!(info.version, "1.0.0");
    }
}