use crate::handler::{handle_mcp_post, handle_oauth_protected_resource, handle_sse};
use crate::state::{HasServerInfo, McpState, OAuthState};
use axum::Router;
use axum::routing::{get, post};
use mcpkit_core::auth::ProtectedResourceMetadata;
use mcpkit_server::{PromptHandler, ResourceHandler, ServerHandler, ToolHandler};
use tower_http::cors::{Any, CorsLayer};
use tower_http::trace::TraceLayer;
/// Builder that assembles an axum [`Router`] exposing an MCP handler over HTTP.
///
/// Construct it with [`McpRouter::new`], adjust it with the `with_*` / path
/// setters, then call [`McpRouter::into_router`] (or [`McpRouter::serve`]).
pub struct McpRouter<H> {
    /// Shared state handed to the POST and SSE route handlers.
    state: McpState<H>,

    // --- route paths ---
    /// Path of the route that accepts POST requests (`handle_mcp_post`).
    post_path: String,
    /// Path of the route that accepts GET requests (`handle_sse`).
    sse_path: String,

    // --- optional middleware / extras ---
    /// When `true`, a permissive `CorsLayer` is attached to the router.
    enable_cors: bool,
    /// When `true`, a `TraceLayer` is attached to the router.
    enable_tracing: bool,
    /// When `Some`, served at `/.well-known/oauth-protected-resource`.
    oauth_metadata: Option<ProtectedResourceMetadata>,
}
impl<H> McpRouter<H>
where
    H: ServerHandler
        + ToolHandler
        + ResourceHandler
        + PromptHandler
        + HasServerInfo
        + Send
        + Sync
        + 'static,
{
    /// Creates a builder that will serve `handler`.
    ///
    /// Defaults:
    /// - POST endpoint at `/mcp`;
    /// - SSE endpoint at `/mcp/sse`;
    /// - CORS and tracing disabled;
    /// - no OAuth protected-resource metadata.
    #[must_use]
    pub fn new(handler: H) -> Self {
        Self {
            state: McpState::new(handler),
            enable_cors: false,
            enable_tracing: false,
            post_path: "/mcp".to_string(),
            sse_path: "/mcp/sse".to_string(),
            oauth_metadata: None,
        }
    }

    /// Enables a CORS layer on every route.
    ///
    /// The layer is fully permissive: it allows any origin, any method and
    /// any header. Only enable it when that is acceptable for the deployment.
    #[must_use]
    pub const fn with_cors(mut self) -> Self {
        self.enable_cors = true;
        self
    }

    /// Enables `tower_http`'s HTTP `TraceLayer` on every route.
    #[must_use]
    pub const fn with_tracing(mut self) -> Self {
        self.enable_tracing = true;
        self
    }

    /// Overrides the path of the POST endpoint (default `/mcp`).
    #[must_use]
    pub fn post_path(mut self, path: impl Into<String>) -> Self {
        self.post_path = path.into();
        self
    }

    /// Overrides the path of the SSE (GET) endpoint (default `/mcp/sse`).
    #[must_use]
    pub fn sse_path(mut self, path: impl Into<String>) -> Self {
        self.sse_path = path.into();
        self
    }

    /// Serves `metadata` at `GET /.well-known/oauth-protected-resource`.
    #[must_use]
    pub fn with_oauth(mut self, metadata: ProtectedResourceMetadata) -> Self {
        self.oauth_metadata = Some(metadata);
        self
    }

    /// Consumes the builder and assembles the axum [`Router`].
    ///
    /// Routes:
    /// - `POST {post_path}` → `handle_mcp_post` with the shared [`McpState`];
    /// - `GET {sse_path}` → `handle_sse` with the same state;
    /// - `GET /.well-known/oauth-protected-resource`, only if
    ///   [`with_oauth`](Self::with_oauth) was called. It uses its own
    ///   `OAuthState`.
    ///
    /// The OAuth route is merged before any layer is added, so the CORS and
    /// tracing layers also cover it. Tracing is added last, so it is the
    /// outermost layer.
    ///
    /// # Panics
    ///
    /// axum panics if a configured path is malformed (e.g. does not start with
    /// `/`) or collides with another route, such as the OAuth well-known path.
    #[must_use]
    pub fn into_router(self) -> Router {
        let mut router = Router::new()
            .route(&self.post_path, post(handle_mcp_post::<H>))
            .route(&self.sse_path, get(handle_sse::<H>))
            .with_state(self.state);

        if let Some(metadata) = self.oauth_metadata {
            // Separate router because the OAuth handler uses a different state type.
            let oauth_router = Router::new()
                .route(
                    "/.well-known/oauth-protected-resource",
                    get(handle_oauth_protected_resource),
                )
                .with_state(OAuthState::new(metadata));
            router = router.merge(oauth_router);
        }

        if self.enable_cors {
            router = router.layer(
                CorsLayer::new()
                    .allow_origin(Any)
                    .allow_methods(Any)
                    .allow_headers(Any),
            );
        }

        if self.enable_tracing {
            router = router.layer(TraceLayer::new_for_http());
        }

        router
    }

    /// Builds the router, binds a TCP listener on `addr` and serves requests.
    ///
    /// # Errors
    ///
    /// Returns the I/O error if binding `addr` fails. Any error produced by
    /// `axum::serve` is returned unchanged.
    pub async fn serve(self, addr: &str) -> std::io::Result<()> {
        let router = self.into_router();
        let listener = tokio::net::TcpListener::bind(addr).await?;
        // `axum::serve` already resolves to `io::Result<()>`. Propagate it
        // as-is rather than re-wrapping it with `io::Error::other`, which
        // would replace the original `ErrorKind` with `Other`.
        axum::serve(listener, router).await
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use mcpkit_core::capability::{ServerCapabilities, ServerInfo};
    use mcpkit_core::error::McpError;
    use mcpkit_core::types::{
        GetPromptResult, Prompt, Resource, ResourceContents, Tool, ToolOutput,
    };
    use mcpkit_server::ServerHandler;
    use mcpkit_server::context::Context;

    // Minimal handler: no tools, resources or prompts. It only exists to
    // satisfy the trait bounds that `McpRouter` requires.
    struct TestHandler;

    impl ServerHandler for TestHandler {
        fn server_info(&self) -> ServerInfo {
            ServerInfo {
                name: "test-server".to_string(),
                version: "1.0.0".to_string(),
                protocol_version: None,
            }
        }

        fn capabilities(&self) -> ServerCapabilities {
            ServerCapabilities::new()
                .with_tools()
                .with_resources()
                .with_prompts()
        }
    }

    impl ToolHandler for TestHandler {
        async fn list_tools(&self, _ctx: &Context<'_>) -> Result<Vec<Tool>, McpError> {
            Ok(vec![])
        }

        async fn call_tool(
            &self,
            _name: &str,
            _args: serde_json::Value,
            _ctx: &Context<'_>,
        ) -> Result<ToolOutput, McpError> {
            Ok(ToolOutput::text("test"))
        }
    }

    impl ResourceHandler for TestHandler {
        async fn list_resources(&self, _ctx: &Context<'_>) -> Result<Vec<Resource>, McpError> {
            Ok(vec![])
        }

        async fn read_resource(
            &self,
            uri: &str,
            _ctx: &Context<'_>,
        ) -> Result<Vec<ResourceContents>, McpError> {
            Ok(vec![ResourceContents::text(uri, "test")])
        }
    }

    impl PromptHandler for TestHandler {
        async fn list_prompts(&self, _ctx: &Context<'_>) -> Result<Vec<Prompt>, McpError> {
            Ok(vec![])
        }

        async fn get_prompt(
            &self,
            _name: &str,
            _args: Option<serde_json::Map<String, serde_json::Value>>,
            _ctx: &Context<'_>,
        ) -> Result<GetPromptResult, McpError> {
            Ok(GetPromptResult {
                description: Some("Test prompt".to_string()),
                messages: vec![],
            })
        }
    }

    // Every builder setter must be reflected in the configuration that
    // `into_router` consumes, and the resulting router must build cleanly.
    #[test]
    fn test_router_builder() {
        let builder = McpRouter::new(TestHandler)
            .with_cors()
            .with_tracing()
            .post_path("/api/mcp")
            .sse_path("/api/sse");

        assert!(builder.enable_cors);
        assert!(builder.enable_tracing);
        assert_eq!(builder.post_path, "/api/mcp");
        assert_eq!(builder.sse_path, "/api/sse");
        assert!(builder.oauth_metadata.is_none());

        let _router = builder.into_router();
    }

    // Pins the documented defaults of `McpRouter::new`.
    #[test]
    fn test_router_defaults() {
        let builder = McpRouter::new(TestHandler);

        assert!(!builder.enable_cors);
        assert!(!builder.enable_tracing);
        assert_eq!(builder.post_path, "/mcp");
        assert_eq!(builder.sse_path, "/mcp/sse");
        assert!(builder.oauth_metadata.is_none());

        let _router = builder.into_router();
    }
}