use crate::handler::{handle_mcp_post, handle_sse, with_protocol_version, with_session_id};
use crate::state::{HasServerInfo, McpState};
use mcpkit_server::{PromptHandler, ResourceHandler, ServerHandler, ToolHandler};
use std::convert::Infallible;
use std::net::SocketAddr;
use std::sync::Arc;
use warp::Filter;
pub struct McpRouter<H> {
state: Arc<McpState<H>>,
enable_cors: bool,
}
impl<H> McpRouter<H>
where
H: ServerHandler
+ ToolHandler
+ ResourceHandler
+ PromptHandler
+ HasServerInfo
+ Send
+ Sync
+ 'static,
{
pub fn new(handler: H) -> Self {
Self {
state: Arc::new(McpState::new(handler)),
enable_cors: false,
}
}
#[must_use]
pub const fn with_cors(mut self) -> Self {
self.enable_cors = true;
self
}
#[must_use]
pub fn into_filter(
self,
) -> impl Filter<Extract = (impl warp::Reply,), Error = warp::Rejection> + Clone {
let state = self.state;
let post_state = state.clone();
let mcp_post = warp::path("mcp")
.and(warp::post())
.and(with_state(post_state))
.and(with_protocol_version())
.and(with_session_id())
.and(warp::body::content_length_limit(1024 * 1024)) .and(warp::body::bytes())
.and_then(
|state: Arc<McpState<H>>,
version: Option<String>,
session_id: Option<String>,
bytes: bytes::Bytes| async move {
let body = String::from_utf8_lossy(&bytes).to_string();
handle_mcp_post(state, version, session_id, body).await
},
);
let sse_state = state;
let mcp_sse = warp::path("mcp")
.and(warp::path("sse"))
.and(warp::get())
.and(with_state(sse_state))
.and(with_session_id())
.map(|state: Arc<McpState<H>>, session_id: Option<String>| {
handle_sse(state, session_id)
});
mcp_post.or(mcp_sse).with(
warp::cors()
.allow_any_origin()
.allow_methods(vec!["GET", "POST", "OPTIONS"])
.allow_headers(vec![
"content-type",
"mcp-protocol-version",
"mcp-session-id",
"last-event-id",
])
.expose_headers(vec!["mcp-session-id"]),
)
}
#[must_use]
pub fn into_filter_without_cors(
self,
) -> impl Filter<Extract = (impl warp::Reply,), Error = warp::Rejection> + Clone {
let state = self.state;
let post_state = state.clone();
let mcp_post = warp::path("mcp")
.and(warp::post())
.and(with_state(post_state))
.and(with_protocol_version())
.and(with_session_id())
.and(warp::body::content_length_limit(1024 * 1024)) .and(warp::body::bytes())
.and_then(
|state: Arc<McpState<H>>,
version: Option<String>,
session_id: Option<String>,
bytes: bytes::Bytes| async move {
let body = String::from_utf8_lossy(&bytes).to_string();
handle_mcp_post(state, version, session_id, body).await
},
);
let sse_state = state;
let mcp_sse = warp::path("mcp")
.and(warp::path("sse"))
.and(warp::get())
.and(with_state(sse_state))
.and(with_session_id())
.map(|state: Arc<McpState<H>>, session_id: Option<String>| {
handle_sse(state, session_id)
});
mcp_post.or(mcp_sse)
}
pub async fn serve(self, addr: impl Into<SocketAddr>) {
let filter = self.into_filter();
warp::serve(filter).run(addr).await;
}
}
fn with_state<H: Send + Sync + 'static>(
state: Arc<McpState<H>>,
) -> impl Filter<Extract = (Arc<McpState<H>>,), Error = Infallible> + Clone {
warp::any().map(move || state.clone())
}
#[cfg(test)]
mod tests {
use super::*;
use mcpkit_core::capability::{ServerCapabilities, ServerInfo};
use mcpkit_core::error::McpError;
use mcpkit_core::types::{
GetPromptResult, Prompt, Resource, ResourceContents, Tool, ToolOutput,
};
use mcpkit_server::ServerHandler;
use mcpkit_server::context::Context;
struct TestHandler;
impl ServerHandler for TestHandler {
fn server_info(&self) -> ServerInfo {
ServerInfo {
name: "test-server".to_string(),
version: "1.0.0".to_string(),
protocol_version: None,
}
}
fn capabilities(&self) -> ServerCapabilities {
ServerCapabilities::new()
.with_tools()
.with_resources()
.with_prompts()
}
}
impl ToolHandler for TestHandler {
async fn list_tools(&self, _ctx: &Context<'_>) -> Result<Vec<Tool>, McpError> {
Ok(vec![])
}
async fn call_tool(
&self,
_name: &str,
_args: serde_json::Value,
_ctx: &Context<'_>,
) -> Result<ToolOutput, McpError> {
Ok(ToolOutput::text("test"))
}
}
impl ResourceHandler for TestHandler {
async fn list_resources(&self, _ctx: &Context<'_>) -> Result<Vec<Resource>, McpError> {
Ok(vec![])
}
async fn read_resource(
&self,
uri: &str,
_ctx: &Context<'_>,
) -> Result<Vec<ResourceContents>, McpError> {
Ok(vec![ResourceContents::text(uri, "test")])
}
}
impl PromptHandler for TestHandler {
async fn list_prompts(&self, _ctx: &Context<'_>) -> Result<Vec<Prompt>, McpError> {
Ok(vec![])
}
async fn get_prompt(
&self,
_name: &str,
_args: Option<serde_json::Map<String, serde_json::Value>>,
_ctx: &Context<'_>,
) -> Result<GetPromptResult, McpError> {
Ok(GetPromptResult {
description: Some("Test prompt".to_string()),
messages: vec![],
})
}
}
#[test]
fn test_router_builder() {
let router = McpRouter::new(TestHandler).with_cors();
let _ = router.into_filter();
}
}